Adjusted Callable type hint
This commit is contained in:
parent
ed645c2fbe
commit
5ebd4225cc
@ -1,4 +1,4 @@
|
||||
from typing import Tuple, Optional
|
||||
from typing import Tuple, Optional, Callable
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -19,8 +19,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
duration: float,
|
||||
verbose: int = 1,
|
||||
learn_sub_trajectories: bool = False,
|
||||
replanning_schedule: Optional[callable] = None,
|
||||
reward_aggregation: callable = np.sum
|
||||
replanning_schedule: Optional[
|
||||
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
||||
reward_aggregation: Callable[[np.ndarray], float] = np.sum
|
||||
):
|
||||
"""
|
||||
gym.Wrapper for leveraging a black box approach with a trajectory generator.
|
||||
|
@ -1,5 +1,5 @@
|
||||
from itertools import chain
|
||||
from typing import Tuple, Type, Union, Optional
|
||||
from typing import Tuple, Type, Union, Optional, Callable
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -123,7 +123,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
||||
|
||||
@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
|
||||
@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
|
||||
def test_aggregation(mp_type: str, reward_aggregation: callable):
|
||||
def test_aggregation(mp_type: str, reward_aggregation: Callable[[np.ndarray], float]):
|
||||
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation},
|
||||
{'trajectory_generator_type': mp_type},
|
||||
{'controller_type': 'motor'},
|
||||
@ -327,4 +327,3 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
|
||||
active_vel = vel[delay_time_steps: joint_time_steps - 2]
|
||||
assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0])
|
||||
assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0])
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
from itertools import chain
|
||||
from types import FunctionType
|
||||
from typing import Tuple, Type, Union, Optional
|
||||
|
||||
import gym
|
||||
@ -120,7 +121,7 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
|
||||
{'basis_generator_type': 'rbf'}, seed=SEED)
|
||||
|
||||
assert env.do_replanning
|
||||
assert env.replanning_schedule
|
||||
assert callable(env.replanning_schedule)
|
||||
# This also verifies we are not adding the TimeAwareObservationWrapper twice
|
||||
assert env.observation_space == env_step.observation_space
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user