From f44f01b478c7f21c4bf555308679c5f1263ead78 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 18 Jun 2023 11:52:35 +0200 Subject: [PATCH] Fix: Allow observation space dict in test_replanning --- test/test_replanning_sequencing.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_replanning_sequencing.py b/test/test_replanning_sequencing.py index 6e4a760..6425ef8 100644 --- a/test/test_replanning_sequencing.py +++ b/test/test_replanning_sequencing.py @@ -80,7 +80,10 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter assert env.learn_sub_trajectories assert env.traj_gen.learn_tau # This also verifies we are not adding the TimeAwareObservationWrapper twice - assert env.observation_space == env_step.observation_space + if env.observation_space.__class__ in [spaces.Dict]: + assert spaces.flatten_space(env.observation_space) == env_step.observation_space + else: + assert env.observation_space == env_step.observation_space done = True @@ -130,7 +133,10 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra assert env.do_replanning assert callable(env.replanning_schedule) # This also verifies we are not adding the TimeAwareObservationWrapper twice - assert env.observation_space == env_step.observation_space + if env.observation_space.__class__ in [spaces.Dict]: + assert spaces.flatten_space(env.observation_space) == env_step.observation_space + else: + assert env.observation_space == env_step.observation_space env.reset(seed=SEED)