diff --git a/test/test_replanning_sequencing.py b/test/test_replanning_sequencing.py index 6e4a760..6425ef8 100644 --- a/test/test_replanning_sequencing.py +++ b/test/test_replanning_sequencing.py @@ -80,7 +80,10 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter assert env.learn_sub_trajectories assert env.traj_gen.learn_tau # This also verifies we are not adding the TimeAwareObservationWrapper twice - assert env.observation_space == env_step.observation_space + if env.observation_space.__class__ in [spaces.Dict]: + assert spaces.flatten_space(env.observation_space) == env_step.observation_space + else: + assert env.observation_space == env_step.observation_space done = True @@ -130,7 +133,10 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra assert env.do_replanning assert callable(env.replanning_schedule) # This also verifies we are not adding the TimeAwareObservationWrapper twice - assert env.observation_space == env_step.observation_space + if env.observation_space.__class__ in [spaces.Dict]: + assert spaces.flatten_space(env.observation_space) == env_step.observation_space + else: + assert env.observation_space == env_step.observation_space env.reset(seed=SEED)