From f07b8a26ac2f93a37d355fb2b8ff136dd6b6f21d Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 10 Jun 2023 18:49:02 +0200 Subject: [PATCH] Made some assertions more verbose for easier debugging --- test/test_black_box.py | 4 ++-- test/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_black_box.py b/test/test_black_box.py index 74985ac..bfde2fb 100644 --- a/test/test_black_box.py +++ b/test/test_black_box.py @@ -124,12 +124,12 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]): {'phase_generator_type': 'exp'}, {'basis_generator_type': basis_generator_type}) - for _ in range(5): + for i in range(5): env.reset() _obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample()) length = info['trajectory_length'] - assert length == env.spec.max_episode_steps + assert length == env.spec.max_episode_steps, f'Expcted total simulation length ({length}) to be equal to spec.max_episode_steps ({env.spec.max_episode_steps}), but was not during test nr. {i}' @pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp']) diff --git a/test/utils.py b/test/utils.py index 2402f98..86e82a2 100644 --- a/test/utils.py +++ b/test/utils.py @@ -76,7 +76,7 @@ def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers for i, time_step in enumerate(zip(*traj1, *traj2)): obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step assert np.allclose( - obs1, obs2), f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match." + obs1, obs2), f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match: Biggest difference is {np.abs(obs1-obs2).max()} at index {np.abs(obs1-obs2).argmax()}." assert np.array_equal( ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match." assert np.array_equal( @@ -89,7 +89,7 @@ def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"): assert observation_space.contains(obs), \ - f"Observation {obs} received from {obs_type} not contained in observation space {observation_space}." + f"Observation {obs} ({obs.shape}) received from {obs_type} not contained in observation space {observation_space}." def verify_reward(reward):