fancy_gym/test/test_replanning_envs.py

82 lines
2.7 KiB
Python
Raw Normal View History

from itertools import chain
2022-11-09 17:54:34 +01:00
from typing import Tuple, Type, Union, Optional, Callable
2022-11-09 17:54:34 +01:00
import gym
import numpy as np
import pytest
2022-11-09 17:54:34 +01:00
from gym import register
from gym.core import ActType, ObsType
import fancy_gym
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
import fancy_gym
from test.utils import run_env, run_env_determinism
# Environment ids stored under the 'ProDMP' key of fancy_gym's two
# movement-primitive registries (ALL_FANCY_... and ALL_...).
Fancy_ProDMP_IDS, All_ProDMP_IDS = (
    fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'],
    fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'],
)
2022-10-26 15:18:37 +02:00
2022-11-09 17:54:34 +01:00
class Object:
    """Empty placeholder class, used as the type of ``ToyEnv``'s ``e`` argument."""
class ToyEnv(gym.Env):
    """Trivial one-dimensional environment that always returns the same values.

    ``reset`` returns ``np.array([-1])`` and ``step`` returns
    ``(np.array([-1]), 1, False, {})`` regardless of the action. The
    constructor arguments ``a``-``e`` are only stored as attributes; the
    class itself never reads them.
    """

    observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    dt = 0.02

    def __init__(self, a: int = 0, b: float = 0.0, c: Optional[list] = None,
                 d: Optional[dict] = None, e: Optional[Object] = None):
        # Defaults like [], {} or Object() would be evaluated once at class
        # definition time and shared by every instance; build fresh ones here.
        self.a, self.b = a, b
        self.c = [] if c is None else c
        self.d = {} if d is None else d
        self.e = Object() if e is None else e

    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
              options: Optional[dict] = None) -> Union[ObsType, Tuple[ObsType, dict]]:
        # seed / return_info / options are accepted for signature compatibility
        # but ignored: the observation is always the same.
        return np.array([-1])

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        # (observation, reward, done flag, info); the action is ignored.
        return np.array([-1]), 1, False, {}

    def render(self, mode="human"):
        pass
class ToyWrapper(RawInterfaceWrapper):
    """Raw-interface wrapper that reports a fixed state for the toy environment.

    The position is all ones and the velocity all zeros, each shaped like
    the wrapped action space.
    """

    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        """Constant position: an array of ones matching the action shape."""
        act_shape = self.action_space.shape
        return np.ones(shape=act_shape)

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        """Constant velocity: an array of zeros matching the action shape."""
        act_shape = self.action_space.shape
        return np.zeros(shape=act_shape)
2022-10-26 15:18:37 +02:00
2022-11-09 17:54:34 +01:00
@pytest.fixture(scope="session", autouse=True)
def setup():
    """Register the ``toy-v0`` environment once for the whole test session.

    The entry point names this module's own ``ToyEnv`` so the tests here do
    not silently depend on the copy defined in ``test.test_black_box``.
    """
    register(
        id='toy-v0',
        entry_point='test.test_replanning_envs:ToyEnv',
        max_episode_steps=50,
    )
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
# def test_replanning_envs(env_id: str):
# """Tests that ProDMP environments run without errors using random actions."""
# run_env(env_id)
#
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
# def test_replanning_determinism(env_id: str):
# """Tests that ProDMP environments are deterministic."""
# run_env_determinism(env_id, 0)
2022-10-26 15:18:37 +02:00
2022-11-09 17:54:34 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
def test_missing_local_state(mp_type: str):
    """Stepping a black-box env built on the bare ``RawInterfaceWrapper`` must raise.

    The env is created with ``RawInterfaceWrapper`` itself rather than a
    subclass. ``step`` is expected to raise ``NotImplementedError``,
    presumably because the base wrapper leaves the local-state properties
    (position/velocity) unimplemented — confirm against RawInterfaceWrapper.
    """
    # ProDMP uses its own basis generator; the other MP types use RBF bases.
    if mp_type == 'prodmp':
        basis_generator_type = 'prodmp'
    else:
        basis_generator_type = 'rbf'

    traj_gen_kwargs = {'trajectory_generator_type': mp_type}
    controller_kwargs = {'controller_type': 'motor'}
    phase_kwargs = {'phase_generator_type': 'exp'}
    basis_kwargs = {'basis_generator_type': basis_generator_type}

    env = fancy_gym.make_bb('toy-v0', [RawInterfaceWrapper], {},
                            traj_gen_kwargs,
                            controller_kwargs,
                            phase_kwargs,
                            basis_kwargs)
    env.reset()
    with pytest.raises(NotImplementedError):
        sampled_action = env.action_space.sample()
        env.step(sampled_action)