Made some assertions more verbose for easier debugging
This commit is contained in:
parent
40d2409c26
commit
f07b8a26ac
@ -124,12 +124,12 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
||||
{'phase_generator_type': 'exp'},
|
||||
{'basis_generator_type': basis_generator_type})
|
||||
|
||||
for _ in range(5):
|
||||
for i in range(5):
|
||||
env.reset()
|
||||
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
||||
length = info['trajectory_length']
|
||||
|
||||
assert length == env.spec.max_episode_steps
|
||||
assert length == env.spec.max_episode_steps, f'Expcted total simulation length ({length}) to be equal to spec.max_episode_steps ({env.spec.max_episode_steps}), but was not during test nr. {i}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
|
||||
|
@ -76,7 +76,7 @@ def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers
|
||||
for i, time_step in enumerate(zip(*traj1, *traj2)):
|
||||
obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
|
||||
assert np.allclose(
|
||||
obs1, obs2), f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match."
|
||||
obs1, obs2), f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match: Biggest difference is {np.abs(obs1-obs2).max()} at index {np.abs(obs1-obs2).argmax()}."
|
||||
assert np.array_equal(
|
||||
ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
|
||||
assert np.array_equal(
|
||||
@ -89,7 +89,7 @@ def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers
|
||||
|
||||
def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):
|
||||
assert observation_space.contains(obs), \
|
||||
f"Observation {obs} received from {obs_type} not contained in observation space {observation_space}."
|
||||
f"Observation {obs} ({obs.shape}) received from {obs_type} not contained in observation space {observation_space}."
|
||||
|
||||
|
||||
def verify_reward(reward):
|
||||
|
Loading…
Reference in New Issue
Block a user