Adjusted Callable type hint

This commit is contained in:
Fabian 2022-10-24 09:24:12 +02:00
parent ed645c2fbe
commit 5ebd4225cc
3 changed files with 8 additions and 7 deletions

View File

@ -1,4 +1,4 @@
from typing import Tuple, Optional from typing import Tuple, Optional, Callable
import gym import gym
import numpy as np import numpy as np
@ -19,8 +19,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
duration: float, duration: float,
verbose: int = 1, verbose: int = 1,
learn_sub_trajectories: bool = False, learn_sub_trajectories: bool = False,
replanning_schedule: Optional[callable] = None, replanning_schedule: Optional[
reward_aggregation: callable = np.sum Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
reward_aggregation: Callable[[np.ndarray], float] = np.sum
): ):
""" """
gym.Wrapper for leveraging a black box approach with a trajectory generator. gym.Wrapper for leveraging a black box approach with a trajectory generator.

View File

@ -1,5 +1,5 @@
from itertools import chain from itertools import chain
from typing import Tuple, Type, Union, Optional from typing import Tuple, Type, Union, Optional, Callable
import gym import gym
import numpy as np import numpy as np
@ -123,7 +123,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
@pytest.mark.parametrize('mp_type', ['promp', 'dmp']) @pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])]) @pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
def test_aggregation(mp_type: str, reward_aggregation: callable): def test_aggregation(mp_type: str, reward_aggregation: Callable[[np.ndarray], float]):
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation}, env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation},
{'trajectory_generator_type': mp_type}, {'trajectory_generator_type': mp_type},
{'controller_type': 'motor'}, {'controller_type': 'motor'},
@ -327,4 +327,3 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
active_vel = vel[delay_time_steps: joint_time_steps - 2] active_vel = vel[delay_time_steps: joint_time_steps - 2]
assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0]) assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0])
assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0]) assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0])

View File

@ -1,4 +1,5 @@
from itertools import chain from itertools import chain
from types import FunctionType
from typing import Tuple, Type, Union, Optional from typing import Tuple, Type, Union, Optional
import gym import gym
@ -120,7 +121,7 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
{'basis_generator_type': 'rbf'}, seed=SEED) {'basis_generator_type': 'rbf'}, seed=SEED)
assert env.do_replanning assert env.do_replanning
assert env.replanning_schedule assert callable(env.replanning_schedule)
# This also verifies we are not adding the TimeAwareObservationWrapper twice # This also verifies we are not adding the TimeAwareObservationWrapper twice
assert env.observation_space == env_step.observation_space assert env.observation_space == env_step.observation_space