fancy_gym/test/test_replanning_envs.py
2022-11-09 17:54:34 +01:00

82 lines
2.7 KiB
Python

from itertools import chain
from typing import Tuple, Type, Union, Optional, Callable
import gym
import numpy as np
import pytest
from gym import register
from gym.core import ActType, ObsType
import fancy_gym
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
import fancy_gym
from test.utils import run_env, run_env_determinism
# ProDMP environment ids: the fancy_gym-specific subset and the full set across all suites.
Fancy_ProDMP_IDS, All_ProDMP_IDS = (
    fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'],
    fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'],
)
class Object:
    """Empty placeholder type used by ``ToyEnv`` for its ``e`` argument."""
class ToyEnv(gym.Env):
    """Minimal one-dimensional environment used as a black-box test fixture.

    Observations and actions are 1-D boxes in [-1, 1] with ``dt = 0.02``.
    ``reset`` always returns ``[-1]``; ``step`` always returns observation
    ``[-1]``, reward 1, ``done=False`` and an empty info dict. The constructor
    arguments are stored unchanged and not otherwise used.
    """
    observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    dt = 0.02

    def __init__(self, a: int = 0, b: float = 0.0, c: Optional[list] = None,
                 d: Optional[dict] = None, e: Optional[Object] = None):
        """Store the (arbitrary) constructor arguments on the instance.

        Args:
            a, b: Scalar values, stored as-is.
            c: List value; a fresh empty list is used when omitted.
            d: Dict value; a fresh empty dict is used when omitted.
            e: ``Object`` value; a fresh ``Object()`` is used when omitted.
        """
        # Build mutable defaults per instance: a literal ``[]``/``{}``/``Object()``
        # default is evaluated once and would be shared by every ToyEnv.
        self.a, self.b = a, b
        self.c = [] if c is None else c
        self.d = {} if d is None else d
        self.e = Object() if e is None else e

    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
              options: Optional[dict] = None) -> Union[ObsType, Tuple[ObsType, dict]]:
        """Return the constant initial observation ``[-1]`` (arguments are ignored)."""
        return np.array([-1])

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        """Ignore ``action``; return observation ``[-1]``, reward 1, not done, no info."""
        return np.array([-1]), 1, False, {}

    def render(self, mode="human"):
        """No-op render."""
        pass
class ToyWrapper(RawInterfaceWrapper):
    """Interface wrapper that reports a constant local state.

    Position is all ones and velocity all zeros, each shaped like the wrapped
    environment's action space.
    """

    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        """Constant position: an array of ones with the action-space shape."""
        act_shape = self.action_space.shape
        return np.ones(act_shape)

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        """Constant velocity: an array of zeros with the action-space shape."""
        act_shape = self.action_space.shape
        return np.zeros(act_shape)
@pytest.fixture(scope="session", autouse=True)
def setup():
    """Register the ``toy-v0`` gym id once per test session.

    The entry point is built from ``__name__``, so it resolves to the ``ToyEnv``
    defined in this module under whatever module name pytest imported it. This
    replaces a hard-coded reference to ``test.test_black_box``, which tied this
    file to another test module and left the local ``ToyEnv`` unused.
    """
    register(
        id='toy-v0',
        entry_point=f'{__name__}:ToyEnv',
        max_episode_steps=50,
    )
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
# def test_replanning_envs(env_id: str):
# """Tests that ProDMP environments run without errors using random actions."""
# run_env(env_id)
#
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
# def test_replanning_determinism(env_id: str):
# """Tests that ProDMP environments are deterministic."""
# run_env_determinism(env_id, 0)
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
def test_missing_local_state(mp_type: str):
    """Stepping a black-box env wrapped only in the bare RawInterfaceWrapper must fail.

    The wrapper used here does not override ``current_pos``/``current_vel`` the
    way ``ToyWrapper`` does. The test asserts that ``step`` then raises
    ``NotImplementedError`` for every trajectory-generator type.
    """
    # ProDMP needs its own basis generator; the other MP types use RBF bases.
    if mp_type == 'prodmp':
        basis_generator_type = 'prodmp'
    else:
        basis_generator_type = 'rbf'

    traj_gen_kwargs = {'trajectory_generator_type': mp_type}
    controller_kwargs = {'controller_type': 'motor'}
    phase_kwargs = {'phase_generator_type': 'exp'}
    basis_kwargs = {'basis_generator_type': basis_generator_type}

    bb_env = fancy_gym.make_bb('toy-v0', [RawInterfaceWrapper], {},
                               traj_gen_kwargs,
                               controller_kwargs,
                               phase_kwargs,
                               basis_kwargs)
    bb_env.reset()
    with pytest.raises(NotImplementedError):
        bb_env.step(bb_env.action_space.sample())