Fix: Need to supply seed to reset in tests

Dominik Moritz Roth 2023-06-18 11:51:01 +02:00
parent 9605f2e56c
commit fbba129034
3 changed files with 17 additions and 16 deletions
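Background: Gymnasium handles environment seeding through `reset(seed=...)`; a bare `env.reset()` leaves the episode unseeded, so the tests were not reproducible. A minimal sketch of the pattern the commit applies, using a hypothetical SEED value and a stand-in environment instead of the fancy_gym test fixtures:

import gymnasium as gym

SEED = 1  # hypothetical value; the test suite defines its own module-level SEED constant

env = gym.make("CartPole-v1")  # stand-in environment for illustration

# Gymnasium seeds the environment's RNG via reset(seed=...);
# omitting the seed makes observations and sampled transitions non-deterministic.
obs, info = env.reset(seed=SEED)
_obs, _reward, terminated, truncated, _info = env.step(env.action_space.sample())
env.close()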

View File

@@ -78,7 +78,7 @@ def test_missing_local_state(mp_type: str):
                             {'controller_type': 'motor'},
                             {'phase_generator_type': 'exp'},
                             {'basis_generator_type': basis_generator_type})
-    env.reset()
+    env.reset(seed=SEED)
     with pytest.raises(NotImplementedError):
         env.step(env.action_space.sample())
@@ -95,7 +95,7 @@ def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]
                             {'controller_type': 'motor'},
                             {'phase_generator_type': 'exp'},
                             {'basis_generator_type': basis_generator_type})
-    env.reset()
+    env.reset(seed=SEED)
     _obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
     info_keys = list(info.keys())
@@ -125,7 +125,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
                             {'basis_generator_type': basis_generator_type})
     for i in range(5):
-        env.reset()
+        env.reset(seed=SEED)
         _obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
         length = info['trajectory_length']
@@ -141,7 +141,7 @@ def test_aggregation(mp_type: str, reward_aggregation: Callable[[np.ndarray], fl
                             {'controller_type': 'motor'},
                             {'phase_generator_type': 'exp'},
                             {'basis_generator_type': basis_generator_type})
-    env.reset()
+    env.reset(seed=SEED)
     # ToyEnv only returns 1 as reward
     _obs, reward, _terminated, _truncated, _info = env.step(env.action_space.sample())
     assert reward == reward_aggregation(np.ones(50, ))
@@ -232,7 +232,7 @@ def test_learn_tau(mp_type: str, tau: float):
     done = True
     for i in range(5):
         if done:
-            env.reset()
+            env.reset(seed=SEED)
         action = env.action_space.sample()
         action[0] = tau
@@ -278,7 +278,7 @@ def test_learn_delay(mp_type: str, delay: float):
     done = True
     for i in range(5):
         if done:
-            env.reset()
+            env.reset(seed=SEED)
         action = env.action_space.sample()
         action[0] = delay
@@ -327,7 +327,7 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
     done = True
     for i in range(5):
         if done:
-            env.reset()
+            env.reset(seed=SEED)
         action = env.action_space.sample()
         action[0] = tau
         action[1] = delay

View File

@@ -7,6 +7,7 @@ import numpy as np
 import pytest
 from gymnasium import register
 from gymnasium.core import ActType, ObsType
+from gymnasium import spaces
 import fancy_gym
 from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
@@ -85,7 +86,7 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
     for i in range(25):
         if done:
-            env.reset()
+            env.reset(seed=SEED)
         action = env.action_space.sample()
         _obs, _reward, terminated, truncated, info = env.step(action)
         done = terminated or truncated
@@ -131,7 +132,7 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
     # This also verifies we are not adding the TimeAwareObservationWrapper twice
     assert env.observation_space == env_step.observation_space
-    env.reset()
+    env.reset(seed=SEED)
     episode_steps = env_step.spec.max_episode_steps // replanning_time
     # Make 3 episodes, total steps depend on the replanning steps
@@ -146,7 +147,7 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
             # Check if number of steps until termination match the replanning interval
             print(done, (i + 1), episode_steps)
             assert (i + 1) % episode_steps == 0
-            env.reset()
+            env.reset(seed=SEED)
     assert replanning_schedule(None, None, None, None, length)
@@ -171,7 +172,7 @@ def test_max_planning_times(mp_type: str, max_planning_times: int, sub_segment_s
                             {'basis_generator_type': basis_generator_type,
                              },
                             seed=SEED)
-    _ = env.reset()
+    _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
     while not done:
@@ -203,7 +204,7 @@ def test_replanning_with_learn_tau(mp_type: str, max_planning_times: int, sub_se
                             {'basis_generator_type': basis_generator_type,
                              },
                             seed=SEED)
-    _ = env.reset()
+    _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
     while not done:
@@ -236,7 +237,7 @@ def test_replanning_with_learn_delay(mp_type: str, max_planning_times: int, sub_
                             {'basis_generator_type': basis_generator_type,
                              },
                             seed=SEED)
-    _ = env.reset()
+    _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
     while not done:
@@ -291,7 +292,7 @@ def test_replanning_with_learn_delay_and_tau(mp_type: str, max_planning_times: i
                             {'basis_generator_type': basis_generator_type,
                              },
                             seed=SEED)
-    _ = env.reset()
+    _ = env.reset(seed=SEED)
     done = False
     planning_times = 0
     while not done:
@@ -340,7 +341,7 @@ def test_replanning_schedule(mp_type: str, max_planning_times: int, sub_segment_
                             {'basis_generator_type': basis_generator_type,
                              },
                             seed=SEED)
-    _ = env.reset()
+    _ = env.reset(seed=SEED)
     for i in range(max_planning_times):
         action = env.action_space.sample()
         _obs, _reward, terminated, truncated, _info = env.step(action)

View File

@@ -30,7 +30,7 @@ def run_env(env_id: str, iterations: int = None, seed: int = 0, wrappers: List[T
     actions = []
     terminations = []
     truncations = []
-    obs, _ = env.reset()
+    obs, _ = env.reset(seed=seed)
     verify_observations(obs, env.observation_space, "reset()")
     iterations = iterations or (env.spec.max_episode_steps or 1)