fancy_gym/test/test_black_box.py
2022-09-26 16:11:41 +02:00

105 lines
3.8 KiB
Python

from itertools import chain
from typing import Tuple, Type, Union, Optional
import gym
import numpy as np
import pytest
from gym import register
from gym.wrappers import TimeLimit
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from gym.core import ActType, ObsType
import fancy_gym
# Seed passed to ``fancy_gym.make`` whenever a test builds a step-based environment.
SEED = 1

# Environment ids and their MP wrapper classes. The two lists are paired by position
# (tests zip them together for parametrization), so they must stay in the same order.
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch', 'metaworld:reach-v2', 'Reacher-v2']
WRAPPERS = [
    fancy_gym.envs.mujoco.reacher.MPWrapper,
    fancy_gym.dmc.suite.ball_in_cup.MPWrapper,
    fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper,
    fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper,
]

# Flattened collection of every id found in the values of
# ``fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS``.
# Materialized as a list on purpose: a bare ``chain`` object is a one-shot iterator, so a
# module-level constant built from it would be empty for every consumer after the first
# (e.g. a second ``pytest.mark.parametrize`` using it would silently collect zero cases).
ALL_MP_ENVS = list(chain.from_iterable(fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
class ToyEnv(gym.Env):
    """Minimal environment with constant outputs, used as a predictable test double.

    Both the observation and the action space are one-dimensional float64 boxes
    bounded by [-1, 1]. Every call to ``step`` yields the same observation, a reward
    of 1 and ``done=False``, regardless of the action.
    """

    observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    # Not referenced inside this class; presumably the control time step read by the
    # black-box machinery -- TODO confirm.
    dt = 0.01

    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
              options: Optional[dict] = None) -> Union[ObsType, Tuple[ObsType, dict]]:
        """Return the constant initial observation.

        ``seed``, ``return_info`` and ``options`` are accepted for signature
        compatibility only; none of them is used, and no info dict is ever returned.
        """
        # NOTE(review): the array holds an integer, while the space declares float64.
        initial_observation = np.array([-1])
        return initial_observation

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        """Ignore ``action`` and return ``(observation, reward, done, info)`` constants."""
        observation = np.array([-1])
        reward, done, info = 1, False, {}
        return observation, reward, done, info

    def render(self, mode="human"):
        """Do nothing; the toy environment has no visual representation."""
        return None
class ToyWrapper(RawInterfaceWrapper):
    """Raw-interface wrapper for ``ToyEnv`` that reports constant position and velocity."""

    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        """Return an array of ones with the same shape as the action space."""
        action_shape = self.action_space.shape
        return np.ones(action_shape)

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        """Return an array of ones with the same shape as the action space."""
        action_shape = self.action_space.shape
        return np.ones(action_shape)
@pytest.fixture(scope="session", autouse=True)
def setup():
    """Register the toy environment under the id ``toy-v0``.

    The fixture is session-scoped and auto-used, so the registration runs once
    before the tests of the session without being requested explicitly.
    """
    toy_registration = {
        'id': 'toy-v0',
        # The entry point is resolved by module path, so ``test.test_black_box``
        # must be importable from where the tests are launched.
        'entry_point': 'test.test_black_box:ToyEnv',
        'max_episode_steps': 50,
    }
    register(**toy_registration)
@pytest.mark.parametrize('env_id', ENV_IDS)
def test_missing_wrapper(env_id: str):
    """Expect ``make_bb`` to raise ``ValueError`` when no wrapper is supplied."""
    no_wrappers = []
    # Five distinct empty dicts fill the remaining positional keyword-dict arguments.
    empty_settings = [{} for _ in range(5)]
    with pytest.raises(ValueError):
        fancy_gym.make_bb(env_id, no_wrappers, *empty_settings)
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
def test_context_space(env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
    """Check that the black-box observation space matches the wrapper's context mask.

    Args:
        env_wrap: An ``(environment id, wrapper class)`` pair taken position-wise
            from ``ENV_IDS`` and ``WRAPPERS``.
    """
    env_id, wrapper_class = env_wrap
    generator_settings = (
        {'trajectory_generator_type': 'promp'},
        {'controller_type': 'motor'},
        {'phase_generator_type': 'linear'},
        {'basis_generator_type': 'rbf'},
    )
    black_box_env = fancy_gym.make_bb(env_id, [wrapper_class], {}, *generator_settings)

    # Build the reference by hand: wrap a plain step-based env and keep only the
    # entries for which the wrapper's boolean context mask is true.
    reference_wrapper = wrapper_class(fancy_gym.make(env_id, SEED))
    masked_entries = reference_wrapper.context_mask[reference_wrapper.context_mask]
    assert black_box_env.observation_space.shape == masked_entries.shape
@pytest.mark.parametrize('env_id', ENV_IDS)
@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
def test_aggregation(env_id: str, reward_aggregation: callable):
    """Check that a black-box step returns the aggregated per-step rewards.

    Args:
        env_id: Unused in the body; the test always builds ``'toy-v0'``, so each
            aggregation case is simply repeated once per entry of ``ENV_IDS``.
        reward_aggregation: Callable that reduces the sequence of step rewards.
    """
    aggregation_settings = {'reward_aggregation': reward_aggregation}
    env = fancy_gym.make_bb(
        'toy-v0',
        [ToyWrapper],
        aggregation_settings,
        {'trajectory_generator_type': 'promp'},
        {'controller_type': 'motor'},
        {'phase_generator_type': 'linear'},
        {'basis_generator_type': 'rbf'},
    )
    env.reset()

    # ToyEnv only returns 1 as reward. The expected value aggregates 50 of them,
    # the same number as the ``max_episode_steps`` used when registering 'toy-v0'
    # (presumably one black-box step rolls out the whole episode -- TODO confirm).
    expected_rewards = np.ones(50)
    step_result = env.step(env.action_space.sample())
    aggregated_reward = step_result[1]
    assert aggregated_reward == reward_aggregation(expected_rewards)
@pytest.mark.parametrize('env_id', ENV_IDS)
@pytest.mark.parametrize('add_time_aware_wrapper_before', [True, False])
def test_learn_sub_trajectories(env_id: str, add_time_aware_wrapper_before: bool):
    """Check the sub-trajectory flags and the observation space of a black-box env.

    NOTE(review): ``make_bb`` is called with an empty wrapper list, which
    ``test_missing_wrapper`` expects to raise ``ValueError``; if that holds, this
    test errors before reaching its assertions -- confirm the intended wrappers.

    Args:
        env_id: Id of the environment to build.
        add_time_aware_wrapper_before: Currently has no effect; the branch that
            consumes it is an empty placeholder.
    """
    step_based_env = fancy_gym.make(env_id, SEED)
    phase_settings = {'phase_generator_type': 'linear'}
    black_box_env = fancy_gym.make_bb(env_id, [], {}, {}, {}, phase_settings, {})

    # has time aware wrapper
    # NOTE(review): placeholder -- nothing is added or changed for this case yet.
    if add_time_aware_wrapper_before:
        pass

    assert black_box_env.learn_sub_trajectories
    assert black_box_env.learn_tau
    assert black_box_env.observation_space == step_based_env.observation_space