From 5ebd4225ccac57598cde1dd20d71cb3ec858c576 Mon Sep 17 00:00:00 2001 From: Fabian Date: Mon, 24 Oct 2022 09:24:12 +0200 Subject: [PATCH] Adjusted Callable type hint --- fancy_gym/black_box/black_box_wrapper.py | 7 ++++--- test/test_black_box.py | 5 ++--- test/test_replanning_sequencing.py | 3 ++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 7652bd5..16e4017 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -1,4 +1,4 @@ -from typing import Tuple, Optional +from typing import Tuple, Optional, Callable import gym import numpy as np @@ -19,8 +19,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): duration: float, verbose: int = 1, learn_sub_trajectories: bool = False, - replanning_schedule: Optional[callable] = None, - reward_aggregation: callable = np.sum + replanning_schedule: Optional[ + Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None, + reward_aggregation: Callable[[np.ndarray], float] = np.sum ): """ gym.Wrapper for leveraging a black box approach with a trajectory generator. diff --git a/test/test_black_box.py b/test/test_black_box.py index f1b360b..d5e3a88 100644 --- a/test/test_black_box.py +++ b/test/test_black_box.py @@ -1,5 +1,5 @@ from itertools import chain -from typing import Tuple, Type, Union, Optional +from typing import Tuple, Type, Union, Optional, Callable import gym import numpy as np @@ -123,7 +123,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]): @pytest.mark.parametrize('mp_type', ['promp', 'dmp']) @pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])]) -def test_aggregation(mp_type: str, reward_aggregation: callable): +def test_aggregation(mp_type: str, reward_aggregation: Callable[[np.ndarray], float]): env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation}, {'trajectory_generator_type': mp_type}, {'controller_type': 'motor'}, @@ -327,4 +327,3 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float): active_vel = vel[delay_time_steps: joint_time_steps - 2] assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0]) assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0]) - diff --git a/test/test_replanning_sequencing.py b/test/test_replanning_sequencing.py index 64045a5..a42bb65 100644 --- a/test/test_replanning_sequencing.py +++ b/test/test_replanning_sequencing.py @@ -1,4 +1,5 @@ from itertools import chain +from types import FunctionType from typing import Tuple, Type, Union, Optional import gym @@ -120,7 +121,7 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra {'basis_generator_type': 'rbf'}, seed=SEED) assert env.do_replanning - assert env.replanning_schedule + assert callable(env.replanning_schedule) # This also verifies we are not adding the TimeAwareObservationWrapper twice assert env.observation_space == env_step.observation_space