Adjusted Callable type hint
This commit is contained in:
parent
ed645c2fbe
commit
5ebd4225cc
@ -1,4 +1,4 @@
|
|||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional, Callable
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -19,8 +19,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
duration: float,
|
duration: float,
|
||||||
verbose: int = 1,
|
verbose: int = 1,
|
||||||
learn_sub_trajectories: bool = False,
|
learn_sub_trajectories: bool = False,
|
||||||
replanning_schedule: Optional[callable] = None,
|
replanning_schedule: Optional[
|
||||||
reward_aggregation: callable = np.sum
|
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
||||||
|
reward_aggregation: Callable[[np.ndarray], float] = np.sum
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
gym.Wrapper for leveraging a black box approach with a trajectory generator.
|
gym.Wrapper for leveraging a black box approach with a trajectory generator.
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Tuple, Type, Union, Optional
|
from typing import Tuple, Type, Union, Optional, Callable
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -123,7 +123,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
|||||||
|
|
||||||
@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
|
@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
|
||||||
@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
|
@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
|
||||||
def test_aggregation(mp_type: str, reward_aggregation: callable):
|
def test_aggregation(mp_type: str, reward_aggregation: Callable[[np.ndarray], float]):
|
||||||
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation},
|
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation},
|
||||||
{'trajectory_generator_type': mp_type},
|
{'trajectory_generator_type': mp_type},
|
||||||
{'controller_type': 'motor'},
|
{'controller_type': 'motor'},
|
||||||
@ -327,4 +327,3 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
|
|||||||
active_vel = vel[delay_time_steps: joint_time_steps - 2]
|
active_vel = vel[delay_time_steps: joint_time_steps - 2]
|
||||||
assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0])
|
assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0])
|
||||||
assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0])
|
assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0])
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from types import FunctionType
|
||||||
from typing import Tuple, Type, Union, Optional
|
from typing import Tuple, Type, Union, Optional
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
@ -120,7 +121,7 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
|
|||||||
{'basis_generator_type': 'rbf'}, seed=SEED)
|
{'basis_generator_type': 'rbf'}, seed=SEED)
|
||||||
|
|
||||||
assert env.do_replanning
|
assert env.do_replanning
|
||||||
assert env.replanning_schedule
|
assert callable(env.replanning_schedule)
|
||||||
# This also verifies we are not adding the TimeAwareObservationWrapper twice
|
# This also verifies we are not adding the TimeAwareObservationWrapper twice
|
||||||
assert env.observation_space == env_step.observation_space
|
assert env.observation_space == env_step.observation_space
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user