diff --git a/alr_envs/alr/classic_control/base_reacher/base_reacher.py b/alr_envs/alr/classic_control/base_reacher/base_reacher.py index 1b1ad19..e76ce85 100644 --- a/alr_envs/alr/classic_control/base_reacher/base_reacher.py +++ b/alr_envs/alr/classic_control/base_reacher/base_reacher.py @@ -1,10 +1,11 @@ -from typing import Iterable, Union from abc import ABC, abstractmethod +from typing import Union + import gym -import matplotlib.pyplot as plt import numpy as np from gym import spaces from gym.utils import seeding + from alr_envs.alr.classic_control.utils import intersect diff --git a/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py b/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py deleted file mode 100644 index e249a71..0000000 --- a/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Tuple, Union - -import numpy as np - -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper - - -class MPWrapper(RawInterfaceWrapper): - - def get_context_mask(self): - return np.hstack([ - [self.env.random_start] * self.env.n_links, # cos - [self.env.random_start] * self.env.n_links, # sin - [self.env.random_start] * self.env.n_links, # velocity - [self.env.initial_width is None], # hole width - # [self.env.hole_depth is None], # hole depth - [True] * 2, # x-y coordinates of target distance - [False] # env steps - ]) - - @property - def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.current_pos - - @property - def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.current_vel diff --git a/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py b/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py index 68d203f..9f40292 100644 --- a/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py +++ b/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py @@ -6,8 +6,8 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): - @property - def context_mask(self) -> np.ndarray: + + def context_mask(self): return np.hstack([ [self.env.random_start] * self.env.n_links, # cos [self.env.random_start] * self.env.n_links, # sin @@ -25,10 +25,6 @@ class MPWrapper(RawInterfaceWrapper): def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.current_vel - @property - def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]: - raise ValueError("Goal position is not available and has to be learnt based on the environment.") - @property def dt(self) -> Union[float, int]: return self.env.dt diff --git a/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py b/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py deleted file mode 100644 index 9f40292..0000000 --- a/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Tuple, Union - -import numpy as np - -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper - - -class MPWrapper(RawInterfaceWrapper): - - def context_mask(self): - return np.hstack([ - [self.env.random_start] * self.env.n_links, # cos - [self.env.random_start] * self.env.n_links, # sin - [self.env.random_start] * self.env.n_links, # velocity - [self.env.initial_via_target is None] * 2, # x-y coordinates of via point distance - [True] * 2, # x-y coordinates of target distance - [False] # env steps - ]) - - @property - def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.current_pos - - @property - def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.current_vel - - @property - def dt(self) -> Union[float, int]: - return self.env.dt diff --git a/alr_envs/alr/mujoco/reacher/__init__.py b/alr_envs/alr/mujoco/reacher/__init__.py index c1a25d3..6a15beb 100644 --- a/alr_envs/alr/mujoco/reacher/__init__.py +++ b/alr_envs/alr/mujoco/reacher/__init__.py @@ -1,2 +1,2 @@ from .mp_wrapper import MPWrapper -from .new_mp_wrapper import MPWrapper as NewMPWrapper \ No newline at end of file +from .mp_wrapper import MPWrapper as NewMPWrapper \ No newline at end of file diff --git a/alr_envs/alr/mujoco/reacher/mp_wrapper.py b/alr_envs/alr/mujoco/reacher/mp_wrapper.py index e51843c..966be23 100644 --- a/alr_envs/alr/mujoco/reacher/mp_wrapper.py +++ b/alr_envs/alr/mujoco/reacher/mp_wrapper.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Tuple import numpy as np @@ -8,37 +8,21 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): @property - def context_mask(self) -> np.ndarray: + def context_mask(self): return np.concatenate([ - [False] * self.n_links, # cos - [False] * self.n_links, # sin + [False] * self.env.n_links, # cos + [False] * self.env.n_links, # sin [True] * 2, # goal position - [False] * self.n_links, # angular velocity + [False] * self.env.n_links, # angular velocity [False] * 3, # goal distance # self.get_body_com("target"), # only return target to make problem harder [False], # step ]) - # @property - # def active_obs(self): - # return np.concatenate([ - # [True] * self.n_links, # cos, True - # [True] * self.n_links, # sin, True - # [True] * 2, # goal position - # [True] * self.n_links, # angular velocity, True - # [True] * 3, # goal distance - # # self.get_body_com("target"), # only return target to make problem harder - # [False], # step - # ]) + @property + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.sim.data.qpos.flat[:self.env.n_links] @property - def current_vel(self) -> Union[float, int, np.ndarray]: - return self.sim.data.qvel.flat[:self.n_links] - - @property - def current_pos(self) -> Union[float, int, np.ndarray]: - return self.sim.data.qpos.flat[:self.n_links] - - @property - def dt(self) -> Union[float, int]: - return self.env.dt \ No newline at end of file + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.sim.data.qvel.flat[:self.env.n_links] diff --git a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py deleted file mode 100644 index 6b50d80..0000000 --- a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py +++ /dev/null @@ -1,28 +0,0 @@ -from alr_envs.mp.black_box_wrapper import BlackBoxWrapper -from typing import Union, Tuple -import numpy as np - -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper - - -class MPWrapper(RawInterfaceWrapper): - - @property - def context_mask(self): - return np.concatenate([ - [False] * self.env.n_links, # cos - [False] * self.env.n_links, # sin - [True] * 2, # goal position - [False] * self.env.n_links, # angular velocity - [False] * 3, # goal distance - # self.get_body_com("target"), # only return target to make problem harder - [False], # step - ]) - - @property - def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.sim.data.qpos.flat[:self.env.n_links] - - @property - def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.sim.data.qvel.flat[:self.env.n_links] diff --git a/alr_envs/mp/black_box_wrapper.py b/alr_envs/mp/black_box_wrapper.py index 0c2a7c8..f1ba41f 100644 --- a/alr_envs/mp/black_box_wrapper.py +++ b/alr_envs/mp/black_box_wrapper.py @@ -11,7 +11,7 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.utils.utils import get_numpy -class BlackBoxWrapper(gym.ObservationWrapper, ABC): +class BlackBoxWrapper(gym.ObservationWrapper): def __init__(self, env: RawInterfaceWrapper, @@ -34,9 +34,8 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC): reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory reward, default summation over all values. """ - super().__init__() + super().__init__(env) - self.env = env self.duration = duration self.learn_sub_trajectories = learn_sub_trajectories self.replanning_schedule = replanning_schedule