2022-10-25 20:10:59 +02:00
|
|
|
from itertools import chain
|
2022-11-09 17:54:34 +01:00
|
|
|
from typing import Tuple, Type, Union, Optional, Callable
|
2022-10-25 20:10:59 +02:00
|
|
|
|
2022-11-09 17:54:34 +01:00
|
|
|
import gym
|
|
|
|
import numpy as np
|
2022-10-25 20:10:59 +02:00
|
|
|
import pytest
|
2022-11-09 17:54:34 +01:00
|
|
|
from gym import register
|
|
|
|
from gym.core import ActType, ObsType
|
|
|
|
|
|
|
|
import fancy_gym
|
|
|
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
2022-10-25 20:10:59 +02:00
|
|
|
|
|
|
|
import fancy_gym
|
|
|
|
from test.utils import run_env, run_env_determinism
|
|
|
|
|
|
|
|
# ProDMP environment ids: those from fancy_gym's own suite, and those across
# every movement-primitive suite fancy_gym registers.
Fancy_ProDMP_IDS, All_ProDMP_IDS = (
    fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'],
    fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'],
)
|
|
|
|
|
2022-10-26 15:18:37 +02:00
|
|
|
|
2022-11-09 17:54:34 +01:00
|
|
|
class Object:
    """Empty placeholder type whose instances serve as bare attribute containers."""
|
|
|
|
|
|
|
|
|
|
|
|
class ToyEnv(gym.Env):
    """Minimal one-dimensional environment used as a fixture for black-box tests.

    Observation and action spaces are 1-D boxes in [-1, 1]. Every ``reset`` and
    ``step`` returns the observation ``[-1]``; each step yields reward 1 and never
    signals termination. Constructor arguments are stored unchanged on the
    instance so tests can inspect what was forwarded to the environment.
    """
    observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    dt = 0.02

    def __init__(self, a: int = 0, b: float = 0.0, c: Optional[list] = None,
                 d: Optional[dict] = None, e: Optional[Object] = None):
        """Store the (otherwise unused) constructor arguments as attributes.

        Args:
            a: Stored as ``self.a``.
            b: Stored as ``self.b``.
            c: Stored as ``self.c``; a fresh empty list when omitted.
            d: Stored as ``self.d``; a fresh empty dict when omitted.
            e: Stored as ``self.e``; a fresh ``Object()`` when omitted.
        """
        self.a, self.b = a, b
        # Build the defaults per instance. Mutable default arguments would be
        # shared by every ToyEnv created without explicit values.
        self.c = [] if c is None else c
        self.d = {} if d is None else d
        self.e = Object() if e is None else e

    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
              options: Optional[dict] = None) -> Union[ObsType, Tuple[ObsType, dict]]:
        """Return the constant initial observation ``[-1]``.

        ``seed`` and ``options`` are accepted for API compatibility and ignored.
        When ``return_info`` is true, an ``(obs, {})`` tuple is returned, as the
        return annotation promises.
        """
        obs = np.array([-1])
        return (obs, {}) if return_info else obs

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        """Ignore ``action``; return observation ``[-1]``, reward 1, not done, empty info."""
        return np.array([-1]), 1, False, {}

    def render(self, mode="human"):
        """No-op: the toy environment has nothing to render."""
        pass
|
|
|
|
|
|
|
|
|
|
|
|
class ToyWrapper(RawInterfaceWrapper):
    """Raw interface that reports a fixed local state for the wrapped environment.

    Position is always all ones and velocity always all zeros, each shaped like
    the wrapped environment's action space.
    """

    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        """All-ones array with the same shape as the action space."""
        shape = self.action_space.shape
        return np.ones(shape)

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        """All-zeros array with the same shape as the action space."""
        shape = self.action_space.shape
        return np.zeros(shape)
|
2022-10-26 15:18:37 +02:00
|
|
|
|
2022-11-09 17:54:34 +01:00
|
|
|
@pytest.fixture(scope="session", autouse=True)
def setup():
    """Register the ``toy-v0`` environment once for the test session.

    The fixture is ``autouse``, so tests can build ``toy-v0`` without requesting it
    explicitly. Episodes are capped at 50 steps via ``max_episode_steps``.
    """
    # NOTE(review): the entry point resolves ToyEnv from ``test.test_black_box``,
    # not necessarily from this module's own ToyEnv — confirm this is intended.
    register(
        id='toy-v0',
        entry_point='test.test_black_box:ToyEnv',
        max_episode_steps=50,
    )
|
|
|
|
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
|
|
|
|
# def test_replanning_envs(env_id: str):
|
|
|
|
# """Tests that ProDMP environments run without errors using random actions."""
|
|
|
|
# run_env(env_id)
|
|
|
|
#
|
|
|
|
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
|
|
|
|
# def test_replanning_determinism(env_id: str):
|
|
|
|
# """Tests that ProDMP environments are deterministic."""
|
|
|
|
# run_env_determinism(env_id, 0)
|
2022-10-26 15:18:37 +02:00
|
|
|
|
2022-11-09 17:54:34 +01:00
|
|
|
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
def test_missing_local_state(mp_type: str):
    """Stepping a black-box env without a concrete local state must raise.

    The plain ``RawInterfaceWrapper`` is passed instead of a subclass such as
    ``ToyWrapper`` that defines ``current_pos`` / ``current_vel``. The first
    ``step`` after ``reset`` is therefore expected to raise ``NotImplementedError``
    for every movement-primitive type.
    """
    # ProDMP needs its own basis generator; the other MP types use RBF bases.
    if mp_type == 'prodmp':
        basis_generator_type = 'prodmp'
    else:
        basis_generator_type = 'rbf'

    traj_gen_kwargs = {'trajectory_generator_type': mp_type}
    controller_kwargs = {'controller_type': 'motor'}
    phase_kwargs = {'phase_generator_type': 'exp'}
    basis_kwargs = {'basis_generator_type': basis_generator_type}

    env = fancy_gym.make_bb('toy-v0', [RawInterfaceWrapper], {},
                            traj_gen_kwargs, controller_kwargs, phase_kwargs, basis_kwargs)
    env.reset()
    with pytest.raises(NotImplementedError):
        env.step(env.action_space.sample())
|