Made some assertions more verbose for easier debugging

This commit is contained in:
Dominik Moritz Roth 2023-06-10 18:49:02 +02:00
parent 40d2409c26
commit f07b8a26ac
2 changed files with 4 additions and 4 deletions

View File

@ -124,12 +124,12 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
{'phase_generator_type': 'exp'},
{'basis_generator_type': basis_generator_type})
for _ in range(5):
for i in range(5):
env.reset()
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
length = info['trajectory_length']
assert length == env.spec.max_episode_steps
assert length == env.spec.max_episode_steps, f'Expcted total simulation length ({length}) to be equal to spec.max_episode_steps ({env.spec.max_episode_steps}), but was not during test nr. {i}'
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])

View File

@ -76,7 +76,7 @@ def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers
for i, time_step in enumerate(zip(*traj1, *traj2)):
obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
assert np.allclose(
obs1, obs2), f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match."
obs1, obs2), f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match: Biggest difference is {np.abs(obs1-obs2).max()} at index {np.abs(obs1-obs2).argmax()}."
assert np.array_equal(
ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
assert np.array_equal(
@ -89,7 +89,7 @@ def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers
def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):
assert observation_space.contains(obs), \
f"Observation {obs} received from {obs_type} not contained in observation space {observation_space}."
f"Observation {obs} ({obs.shape}) received from {obs_type} not contained in observation space {observation_space}."
def verify_reward(reward):