test_black_box.py should use vanilla env.make
This commit is contained in:
parent
b0f7dc6c7c
commit
eb9b6e1e22
@ -4,7 +4,7 @@ from typing import Tuple, Type, Union, Optional, Callable
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from gymnasium import register
|
from gymnasium import register, make
|
||||||
from gymnasium.core import ActType, ObsType
|
from gymnasium.core import ActType, ObsType
|
||||||
|
|
||||||
import fancy_gym
|
import fancy_gym
|
||||||
@ -13,7 +13,7 @@ from fancy_gym.utils.wrappers import TimeAwareObservation
|
|||||||
from test.utils import ugly_hack_to_mitigate_metaworld_bug
|
from test.utils import ugly_hack_to_mitigate_metaworld_bug
|
||||||
|
|
||||||
SEED = 1
|
SEED = 1
|
||||||
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
ENV_IDS = ['Reacher5d-v0', 'dm_control/ball_in_cup-catch-v0', 'metaworld/reach-v2', 'Reacher-v2']
|
||||||
WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in_cup.MPWrapper,
|
WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in_cup.MPWrapper,
|
||||||
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
|
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
|
||||||
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||||
@ -102,7 +102,7 @@ def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]
|
|||||||
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
||||||
info_keys = list(info.keys())
|
info_keys = list(info.keys())
|
||||||
|
|
||||||
env_step = fancy_gym.make(env_id, SEED)
|
env_step = make(env_id)
|
||||||
env_step.reset()
|
env_step.reset()
|
||||||
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
||||||
info_keys_step = info.keys()
|
info_keys_step = info.keys()
|
||||||
@ -161,7 +161,7 @@ def test_context_space(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapp
|
|||||||
{'phase_generator_type': 'exp'},
|
{'phase_generator_type': 'exp'},
|
||||||
{'basis_generator_type': 'rbf'})
|
{'basis_generator_type': 'rbf'})
|
||||||
# check if observation space matches with the specified mask values which are true
|
# check if observation space matches with the specified mask values which are true
|
||||||
env_step = fancy_gym.make(env_id, SEED)
|
env_step = make(env_id)
|
||||||
wrapper = wrapper_class(env_step)
|
wrapper = wrapper_class(env_step)
|
||||||
assert env.observation_space.shape == wrapper.context_mask[wrapper.context_mask].shape
|
assert env.observation_space.shape == wrapper.context_mask[wrapper.context_mask].shape
|
||||||
|
|
||||||
@ -231,8 +231,9 @@ def test_learn_tau(mp_type: str, tau: float):
|
|||||||
'learn_delay': False
|
'learn_delay': False
|
||||||
},
|
},
|
||||||
{'basis_generator_type': basis_generator_type,
|
{'basis_generator_type': basis_generator_type,
|
||||||
}, seed=SEED)
|
})
|
||||||
|
|
||||||
|
env.reset(seed=SEED)
|
||||||
done = True
|
done = True
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
if done:
|
if done:
|
||||||
@ -277,8 +278,9 @@ def test_learn_delay(mp_type: str, delay: float):
|
|||||||
'learn_delay': True
|
'learn_delay': True
|
||||||
},
|
},
|
||||||
{'basis_generator_type': basis_generator_type,
|
{'basis_generator_type': basis_generator_type,
|
||||||
}, seed=SEED)
|
})
|
||||||
|
|
||||||
|
env.reset(seed=SEED)
|
||||||
done = True
|
done = True
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
if done:
|
if done:
|
||||||
@ -323,7 +325,9 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
|
|||||||
'learn_delay': True
|
'learn_delay': True
|
||||||
},
|
},
|
||||||
{'basis_generator_type': basis_generator_type,
|
{'basis_generator_type': basis_generator_type,
|
||||||
}, seed=SEED)
|
})
|
||||||
|
|
||||||
|
env.reset(seed=SEED)
|
||||||
|
|
||||||
if env.spec.max_episode_steps * env.dt < delay + tau:
|
if env.spec.max_episode_steps * env.dt < delay + tau:
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user