test_black_box.py should use vanilla env.make

This commit is contained in:
Dominik Moritz Roth 2023-07-23 12:20:49 +02:00
parent b0f7dc6c7c
commit eb9b6e1e22

View File

@ -4,7 +4,7 @@ from typing import Tuple, Type, Union, Optional, Callable
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import register
from gymnasium import register, make
from gymnasium.core import ActType, ObsType
import fancy_gym
@ -13,7 +13,7 @@ from fancy_gym.utils.wrappers import TimeAwareObservation
from test.utils import ugly_hack_to_mitigate_metaworld_bug
SEED = 1
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
ENV_IDS = ['Reacher5d-v0', 'dm_control/ball_in_cup-catch-v0', 'metaworld/reach-v2', 'Reacher-v2']
WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in_cup.MPWrapper,
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
@ -102,7 +102,7 @@ def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
info_keys = list(info.keys())
env_step = fancy_gym.make(env_id, SEED)
env_step = make(env_id)
env_step.reset()
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
info_keys_step = info.keys()
@ -161,7 +161,7 @@ def test_context_space(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapp
{'phase_generator_type': 'exp'},
{'basis_generator_type': 'rbf'})
# check if observation space matches with the specified mask values which are true
env_step = fancy_gym.make(env_id, SEED)
env_step = make(env_id)
wrapper = wrapper_class(env_step)
assert env.observation_space.shape == wrapper.context_mask[wrapper.context_mask].shape
@ -231,8 +231,9 @@ def test_learn_tau(mp_type: str, tau: float):
'learn_delay': False
},
{'basis_generator_type': basis_generator_type,
}, seed=SEED)
})
env.reset(seed=SEED)
done = True
for i in range(5):
if done:
@ -277,8 +278,9 @@ def test_learn_delay(mp_type: str, delay: float):
'learn_delay': True
},
{'basis_generator_type': basis_generator_type,
}, seed=SEED)
})
env.reset(seed=SEED)
done = True
for i in range(5):
if done:
@ -323,7 +325,9 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
'learn_delay': True
},
{'basis_generator_type': basis_generator_type,
}, seed=SEED)
})
env.reset(seed=SEED)
if env.spec.max_episode_steps * env.dt < delay + tau:
return