Made some assertions more verbose for easier debugging
This commit is contained in:
parent
40d2409c26
commit
f07b8a26ac
@ -124,12 +124,12 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
|||||||
{'phase_generator_type': 'exp'},
|
{'phase_generator_type': 'exp'},
|
||||||
{'basis_generator_type': basis_generator_type})
|
{'basis_generator_type': basis_generator_type})
|
||||||
|
|
||||||
for _ in range(5):
|
for i in range(5):
|
||||||
env.reset()
|
env.reset()
|
||||||
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
||||||
length = info['trajectory_length']
|
length = info['trajectory_length']
|
||||||
|
|
||||||
assert length == env.spec.max_episode_steps
|
assert length == env.spec.max_episode_steps, f'Expcted total simulation length ({length}) to be equal to spec.max_episode_steps ({env.spec.max_episode_steps}), but was not during test nr. {i}'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
|
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
|
||||||
|
@ -76,7 +76,7 @@ def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers
|
|||||||
for i, time_step in enumerate(zip(*traj1, *traj2)):
|
for i, time_step in enumerate(zip(*traj1, *traj2)):
|
||||||
obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
|
obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
|
||||||
assert np.allclose(
|
assert np.allclose(
|
||||||
obs1, obs2), f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match."
|
obs1, obs2), f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match: Biggest difference is {np.abs(obs1-obs2).max()} at index {np.abs(obs1-obs2).argmax()}."
|
||||||
assert np.array_equal(
|
assert np.array_equal(
|
||||||
ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
|
ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
|
||||||
assert np.array_equal(
|
assert np.array_equal(
|
||||||
@ -89,7 +89,7 @@ def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers
|
|||||||
|
|
||||||
def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):
|
def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):
|
||||||
assert observation_space.contains(obs), \
|
assert observation_space.contains(obs), \
|
||||||
f"Observation {obs} received from {obs_type} not contained in observation space {observation_space}."
|
f"Observation {obs} ({obs.shape}) received from {obs_type} not contained in observation space {observation_space}."
|
||||||
|
|
||||||
|
|
||||||
def verify_reward(reward):
|
def verify_reward(reward):
|
||||||
|
Loading…
Reference in New Issue
Block a user