Adjusted Callable type hint

This commit is contained in:
Fabian 2022-10-24 09:24:12 +02:00
parent ed645c2fbe
commit 5ebd4225cc
3 changed files with 8 additions and 7 deletions

View File

@ -1,4 +1,4 @@
from typing import Tuple, Optional
from typing import Tuple, Optional, Callable
import gym
import numpy as np
@ -19,8 +19,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
duration: float,
verbose: int = 1,
learn_sub_trajectories: bool = False,
replanning_schedule: Optional[callable] = None,
reward_aggregation: callable = np.sum
replanning_schedule: Optional[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
reward_aggregation: Callable[[np.ndarray], float] = np.sum
):
"""
gym.Wrapper for leveraging a black box approach with a trajectory generator.

View File

@ -1,5 +1,5 @@
from itertools import chain
from typing import Tuple, Type, Union, Optional
from typing import Tuple, Type, Union, Optional, Callable
import gym
import numpy as np
@ -123,7 +123,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
def test_aggregation(mp_type: str, reward_aggregation: callable):
def test_aggregation(mp_type: str, reward_aggregation: Callable[[np.ndarray], float]):
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation},
{'trajectory_generator_type': mp_type},
{'controller_type': 'motor'},
@ -327,4 +327,3 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
active_vel = vel[delay_time_steps: joint_time_steps - 2]
assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0])
assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0])

View File

@ -1,4 +1,5 @@
from itertools import chain
from types import FunctionType
from typing import Tuple, Type, Union, Optional
import gym
@ -120,7 +121,7 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
{'basis_generator_type': 'rbf'}, seed=SEED)
assert env.do_replanning
assert env.replanning_schedule
assert callable(env.replanning_schedule)
# This also verifies we are not adding the TimeAwareObservationWrapper twice
assert env.observation_space == env_step.observation_space