From 3273f455c538ea1c5a3cea66679ff55c9f475f90 Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 30 Jun 2022 14:08:54 +0200 Subject: [PATCH] wrappers updated --- alr_envs/__init__.py | 2 +- alr_envs/alr/__init__.py | 8 +- .../hole_reacher/mp_wrapper.py | 24 +-- .../hole_reacher/new_mp_wrapper.py | 31 ---- .../simple_reacher/mp_wrapper.py | 12 +- .../simple_reacher/new_mp_wrapper.py | 31 ---- .../viapoint_reacher/mp_wrapper.py | 6 +- .../viapoint_reacher/new_mp_wrapper.py | 2 - alr_envs/alr/mujoco/__init__.py | 11 +- alr_envs/alr/mujoco/ant_jump/mp_wrapper.py | 6 +- .../ball_in_a_cup/ball_in_a_cup_mp_wrapper.py | 6 +- alr_envs/alr/mujoco/beerpong/beerpong.py | 6 +- .../alr/mujoco/beerpong/beerpong_reward.py | 171 ------------------ .../mujoco/beerpong/beerpong_reward_simple.py | 141 --------------- .../alr/mujoco/beerpong/beerpong_simple.py | 166 ----------------- alr_envs/alr/mujoco/beerpong/mp_wrapper.py | 6 +- .../mujoco/half_cheetah_jump/mp_wrapper.py | 6 +- .../alr/mujoco/hopper_jump/hopper_jump.py | 3 +- alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py | 6 +- .../alr/mujoco/hopper_throw/mp_wrapper.py | 6 +- .../alr/mujoco/hopper_throw/new_mp_wrapper.py | 2 - alr_envs/alr/mujoco/reacher/__init__.py | 3 +- alr_envs/alr/mujoco/reacher/mp_wrapper.py | 7 +- alr_envs/alr/mujoco/reacher/new_mp_wrapper.py | 14 +- .../alr/mujoco/table_tennis/mp_wrapper.py | 6 +- .../alr/mujoco/walker_2d_jump/mp_wrapper.py | 6 +- .../mujoco/walker_2d_jump/new_mp_wrapper.py | 2 - .../dmc/manipulation/reach_site/mp_wrapper.py | 6 +- alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py | 6 +- alr_envs/dmc/suite/cartpole/mp_wrapper.py | 7 +- alr_envs/dmc/suite/reacher/mp_wrapper.py | 6 +- alr_envs/examples/examples_dmc.py | 2 +- alr_envs/examples/examples_metaworld.py | 2 +- ...ves.py => examples_movement_primitives.py} | 29 ++- alr_envs/examples/examples_open_ai.py | 2 +- alr_envs/examples/pd_control_gain_tuning.py | 2 +- alr_envs/meta/goal_change_mp_wrapper.py | 6 +- .../goal_endeffector_change_mp_wrapper.py | 6 +- .../meta/goal_object_change_mp_wrapper.py | 6 +- alr_envs/meta/object_change_mp_wrapper.py | 6 +- alr_envs/mp/black_box_wrapper.py | 50 +++-- alr_envs/mp/raw_interface_wrapper.py | 4 +- .../continuous_mountain_car/mp_wrapper.py | 5 +- .../open_ai/mujoco/reacher_v2/mp_wrapper.py | 9 +- alr_envs/open_ai/robotics/fetch/mp_wrapper.py | 4 +- alr_envs/utils/make_env_helpers.py | 71 +++++--- alr_envs/utils/utils.py | 22 +++ 47 files changed, 219 insertions(+), 722 deletions(-) delete mode 100644 alr_envs/alr/classic_control/hole_reacher/new_mp_wrapper.py delete mode 100644 alr_envs/alr/classic_control/simple_reacher/new_mp_wrapper.py delete mode 100644 alr_envs/alr/mujoco/beerpong/beerpong_reward.py delete mode 100644 alr_envs/alr/mujoco/beerpong/beerpong_reward_simple.py delete mode 100644 alr_envs/alr/mujoco/beerpong/beerpong_simple.py rename alr_envs/examples/{examples_motion_primitives.py => examples_movement_primitives.py} (82%) diff --git a/alr_envs/__init__.py b/alr_envs/__init__.py index 858a66c..cf910b9 100644 --- a/alr_envs/__init__.py +++ b/alr_envs/__init__.py @@ -1,6 +1,6 @@ from alr_envs import dmc, meta, open_ai -from alr_envs.utils.make_env_helpers import make, make_dmp_env, make_promp_env, make_rank from alr_envs.utils import make_dmc +from alr_envs.utils.make_env_helpers import make, make_bb, make_rank # Convenience function for all MP environments from .alr import ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 7a7db6d..607ef18 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -406,8 +406,6 @@ ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0") kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_via_point_reacher_promp['wrappers'].append('TODO') # TODO -kwargs_dict_via_point_reacher_promp['movement_primitives_kwargs']['action_dim'] = 5 -kwargs_dict_via_point_reacher_promp['phase_generator_kwargs']['tau'] = 2 kwargs_dict_via_point_reacher_promp['controller_kwargs']['controller_type'] = 'velocity' kwargs_dict_via_point_reacher_promp['name'] = "ViaPointReacherProMP-v0" register( @@ -448,10 +446,10 @@ for _v in _versions: kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_hole_reacher_promp['wrappers'].append('TODO') # TODO kwargs_dict_hole_reacher_promp['ep_wrapper_kwargs']['weight_scale'] = 2 - kwargs_dict_hole_reacher_promp['movement_primitives_kwargs']['action_dim'] = 5 - kwargs_dict_hole_reacher_promp['phase_generator_kwargs']['tau'] = 2 + # kwargs_dict_hole_reacher_promp['movement_primitives_kwargs']['action_dim'] = 5 + # kwargs_dict_hole_reacher_promp['phase_generator_kwargs']['tau'] = 2 kwargs_dict_hole_reacher_promp['controller_kwargs']['controller_type'] = 'velocity' - kwargs_dict_hole_reacher_promp['basis_generator_kwargs']['num_basis'] = 5 + # kwargs_dict_hole_reacher_promp['basis_generator_kwargs']['num_basis'] = 5 kwargs_dict_hole_reacher_promp['name'] = f"alr_envs:{_v}" register( id=_env_id, diff --git a/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py b/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py index feb545f..e249a71 100644 --- a/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py +++ b/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py @@ -2,12 +2,12 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): - @property - def active_obs(self): +class MPWrapper(RawInterfaceWrapper): + + def get_context_mask(self): return np.hstack([ [self.env.random_start] * self.env.n_links, # cos [self.env.random_start] * self.env.n_links, # sin @@ -18,14 +18,6 @@ class MPWrapper(MPEnvWrapper): [False] # env steps ]) - # @property - # def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: - # return self._joint_angles.copy() - # - # @property - # def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - # return self._angle_velocity.copy() - @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.current_pos @@ -33,11 +25,3 @@ class MPWrapper(MPEnvWrapper): @property def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.current_vel - - @property - def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]: - raise ValueError("Goal position is not available and has to be learnt based on the environment.") - - @property - def dt(self) -> Union[float, int]: - return self.env.dt diff --git a/alr_envs/alr/classic_control/hole_reacher/new_mp_wrapper.py b/alr_envs/alr/classic_control/hole_reacher/new_mp_wrapper.py deleted file mode 100644 index 1f1d198..0000000 --- a/alr_envs/alr/classic_control/hole_reacher/new_mp_wrapper.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Tuple, Union - -import numpy as np - -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper - - -class NewMPWrapper(RawInterfaceWrapper): - - def get_context_mask(self): - return np.hstack([ - [self.env.random_start] * self.env.n_links, # cos - [self.env.random_start] * self.env.n_links, # sin - [self.env.random_start] * self.env.n_links, # velocity - [self.env.initial_width is None], # hole width - # [self.env.hole_depth is None], # hole depth - [True] * 2, # x-y coordinates of target distance - [False] # env steps - ]) - - @property - def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.current_pos - - @property - def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.current_vel - - @property - def dt(self) -> Union[float, int]: - return self.env.dt diff --git a/alr_envs/alr/classic_control/simple_reacher/mp_wrapper.py b/alr_envs/alr/classic_control/simple_reacher/mp_wrapper.py index 4b71e3a..30b0985 100644 --- a/alr_envs/alr/classic_control/simple_reacher/mp_wrapper.py +++ b/alr_envs/alr/classic_control/simple_reacher/mp_wrapper.py @@ -2,12 +2,12 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): - @property - def active_obs(self): +class MPWrapper(RawInterfaceWrapper): + + def context_mask(self): return np.hstack([ [self.env.random_start] * self.env.n_links, # cos [self.env.random_start] * self.env.n_links, # sin @@ -24,10 +24,6 @@ class MPWrapper(MPEnvWrapper): def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.current_vel - @property - def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]: - raise ValueError("Goal position is not available and has to be learnt based on the environment.") - @property def dt(self) -> Union[float, int]: return self.env.dt diff --git a/alr_envs/alr/classic_control/simple_reacher/new_mp_wrapper.py b/alr_envs/alr/classic_control/simple_reacher/new_mp_wrapper.py deleted file mode 100644 index c1497e6..0000000 --- a/alr_envs/alr/classic_control/simple_reacher/new_mp_wrapper.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Tuple, Union - -import numpy as np - -from mp_env_api import MPEnvWrapper - -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper - - -class MPWrapper(RawInterfaceWrapper): - - def context_mask(self): - return np.hstack([ - [self.env.random_start] * self.env.n_links, # cos - [self.env.random_start] * self.env.n_links, # sin - [self.env.random_start] * self.env.n_links, # velocity - [True] * 2, # x-y coordinates of target distance - [False] # env steps - ]) - - @property - def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.current_pos - - @property - def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.current_vel - - @property - def dt(self) -> Union[float, int]: - return self.env.dt diff --git a/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py b/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py index 6b3e85d..68d203f 100644 --- a/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py +++ b/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py @@ -2,12 +2,12 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: return np.hstack([ [self.env.random_start] * self.env.n_links, # cos [self.env.random_start] * self.env.n_links, # sin diff --git a/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py b/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py index f02dfe1..9f40292 100644 --- a/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py +++ b/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py @@ -2,8 +2,6 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper - from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper diff --git a/alr_envs/alr/mujoco/__init__.py b/alr_envs/alr/mujoco/__init__.py index 2885321..f2f4536 100644 --- a/alr_envs/alr/mujoco/__init__.py +++ b/alr_envs/alr/mujoco/__init__.py @@ -1,13 +1,12 @@ -from .reacher.balancing import BalancingEnv +from .ant_jump.ant_jump import ALRAntJumpEnv from .ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv from .ball_in_a_cup.biac_pd import ALRBallInACupPDEnv -from .table_tennis.tt_gym import TTEnvGym -from .beerpong.beerpong import ALRBeerBongEnv, ALRBeerBongEnvStepBased, ALRBeerBongEnvStepBasedEpisodicReward, ALRBeerBongEnvFixedReleaseStep -from .ant_jump.ant_jump import ALRAntJumpEnv +from .beerpong.beerpong import ALRBeerBongEnv from .half_cheetah_jump.half_cheetah_jump import ALRHalfCheetahJumpEnv -from .hopper_jump.hopper_jump import ALRHopperJumpEnv, ALRHopperJumpRndmPosEnv, ALRHopperXYJumpEnv, ALRHopperXYJumpEnvStepBased from .hopper_jump.hopper_jump_on_box import ALRHopperJumpOnBoxEnv from .hopper_throw.hopper_throw import ALRHopperThrowEnv from .hopper_throw.hopper_throw_in_basket import ALRHopperThrowInBasketEnv +from .reacher.alr_reacher import ALRReacherEnv +from .reacher.balancing import BalancingEnv +from .table_tennis.tt_gym import TTEnvGym from .walker_2d_jump.walker_2d_jump import ALRWalker2dJumpEnv -from .reacher.alr_reacher import ALRReacherEnv \ No newline at end of file diff --git a/alr_envs/alr/mujoco/ant_jump/mp_wrapper.py b/alr_envs/alr/mujoco/ant_jump/mp_wrapper.py index 4967b64..4d5c0d6 100644 --- a/alr_envs/alr/mujoco/ant_jump/mp_wrapper.py +++ b/alr_envs/alr/mujoco/ant_jump/mp_wrapper.py @@ -2,13 +2,13 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: return np.hstack([ [False] * 111, # ant has 111 dimensional observation space !! [True] # goal height diff --git a/alr_envs/alr/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py b/alr_envs/alr/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py index 945fa8d..609858b 100644 --- a/alr_envs/alr/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py +++ b/alr_envs/alr/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py @@ -2,13 +2,13 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class BallInACupMPWrapper(MPEnvWrapper): +class BallInACupMPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # TODO: @Max Filter observations correctly return np.hstack([ [False] * 7, # cos diff --git a/alr_envs/alr/mujoco/beerpong/beerpong.py b/alr_envs/alr/mujoco/beerpong/beerpong.py index 64d9e78..dfd6ea4 100644 --- a/alr_envs/alr/mujoco/beerpong/beerpong.py +++ b/alr_envs/alr/mujoco/beerpong/beerpong.py @@ -22,7 +22,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle): def __init__( self, frame_skip=1, apply_gravity_comp=True, noisy=False, rndm_goal=False, cup_goal_pos=None - ): + ): cup_goal_pos = np.array(cup_goal_pos if cup_goal_pos is not None else [-0.3, -1.2, 0.840]) if cup_goal_pos.shape[0] == 2: @@ -154,7 +154,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle): success=success, is_collided=is_collided, sim_crash=crash, table_contact_first=int(not self.reward_function.ball_ground_contact_first) - ) + ) infos.update(reward_infos) return ob, reward, done, infos @@ -176,7 +176,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle): cup_goal_diff_top, self.sim.model.body_pos[self.cup_table_id][:2].copy(), [self._steps], - ]) + ]) @property def dt(self): diff --git a/alr_envs/alr/mujoco/beerpong/beerpong_reward.py b/alr_envs/alr/mujoco/beerpong/beerpong_reward.py deleted file mode 100644 index dc39ca8..0000000 --- a/alr_envs/alr/mujoco/beerpong/beerpong_reward.py +++ /dev/null @@ -1,171 +0,0 @@ -import numpy as np - - -class BeerPongReward: - def __init__(self): - - self.robot_collision_objects = ["wrist_palm_link_convex_geom", - "wrist_pitch_link_convex_decomposition_p1_geom", - "wrist_pitch_link_convex_decomposition_p2_geom", - "wrist_pitch_link_convex_decomposition_p3_geom", - "wrist_yaw_link_convex_decomposition_p1_geom", - "wrist_yaw_link_convex_decomposition_p2_geom", - "forearm_link_convex_decomposition_p1_geom", - "forearm_link_convex_decomposition_p2_geom", - "upper_arm_link_convex_decomposition_p1_geom", - "upper_arm_link_convex_decomposition_p2_geom", - "shoulder_link_convex_decomposition_p1_geom", - "shoulder_link_convex_decomposition_p2_geom", - "shoulder_link_convex_decomposition_p3_geom", - "base_link_convex_geom", "table_contact_geom"] - - self.cup_collision_objects = ["cup_geom_table3", "cup_geom_table4", "cup_geom_table5", "cup_geom_table6", - "cup_geom_table7", "cup_geom_table8", "cup_geom_table9", "cup_geom_table10", - # "cup_base_table", "cup_base_table_contact", - "cup_geom_table15", - "cup_geom_table16", - "cup_geom_table17", "cup_geom1_table8", - # "cup_base_table_contact", - # "cup_base_table" - ] - - - self.ball_traj = None - self.dists = None - self.dists_final = None - self.costs = None - self.action_costs = None - self.angle_rewards = None - self.cup_angles = None - self.cup_z_axes = None - self.collision_penalty = 500 - self.reset(None) - - def reset(self, context): - self.ball_traj = [] - self.dists = [] - self.dists_final = [] - self.costs = [] - self.action_costs = [] - self.angle_rewards = [] - self.cup_angles = [] - self.cup_z_axes = [] - self.ball_ground_contact = False - self.ball_table_contact = False - self.ball_wall_contact = False - self.ball_cup_contact = False - - def compute_reward(self, env, action): - self.ball_id = env.sim.model._body_name2id["ball"] - self.ball_collision_id = env.sim.model._geom_name2id["ball_geom"] - self.goal_id = env.sim.model._site_name2id["cup_goal_table"] - self.goal_final_id = env.sim.model._site_name2id["cup_goal_final_table"] - self.cup_collision_ids = [env.sim.model._geom_name2id[name] for name in self.cup_collision_objects] - self.cup_table_id = env.sim.model._body_name2id["cup_table"] - self.table_collision_id = env.sim.model._geom_name2id["table_contact_geom"] - self.wall_collision_id = env.sim.model._geom_name2id["wall"] - self.cup_table_collision_id = env.sim.model._geom_name2id["cup_base_table_contact"] - self.init_ball_pos_site_id = env.sim.model._site_name2id["init_ball_pos_site"] - self.ground_collision_id = env.sim.model._geom_name2id["ground"] - self.robot_collision_ids = [env.sim.model._geom_name2id[name] for name in self.robot_collision_objects] - - goal_pos = env.sim.data.site_xpos[self.goal_id] - ball_pos = env.sim.data.body_xpos[self.ball_id] - ball_vel = env.sim.data.body_xvelp[self.ball_id] - goal_final_pos = env.sim.data.site_xpos[self.goal_final_id] - self.dists.append(np.linalg.norm(goal_pos - ball_pos)) - self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos)) - - action_cost = np.sum(np.square(action)) - self.action_costs.append(action_cost) - - ball_table_bounce = self._check_collision_single_objects(env.sim, self.ball_collision_id, - self.table_collision_id) - - if ball_table_bounce: # or ball_cup_table_cont or ball_wall_con - self.ball_table_contact = True - - ball_cup_cont = self._check_collision_with_set_of_objects(env.sim, self.ball_collision_id, - self.cup_collision_ids) - if ball_cup_cont: - self.ball_cup_contact = True - - ball_wall_cont = self._check_collision_single_objects(env.sim, self.ball_collision_id, self.wall_collision_id) - if ball_wall_cont and not self.ball_table_contact: - self.ball_wall_contact = True - - ball_ground_contact = self._check_collision_single_objects(env.sim, self.ball_collision_id, - self.ground_collision_id) - if ball_ground_contact and not self.ball_table_contact: - self.ball_ground_contact = True - - self._is_collided = self._check_collision_with_itself(env.sim, self.robot_collision_ids) - if env._steps == env.ep_length - 1 or self._is_collided: - - min_dist = np.min(self.dists) - - ball_in_cup = self._check_collision_single_objects(env.sim, self.ball_collision_id, self.cup_table_collision_id) - - cost_offset = 0 - - if self.ball_ground_contact: # or self.ball_wall_contact: - cost_offset += 2 - - if not self.ball_table_contact: - cost_offset += 2 - - if not ball_in_cup: - cost_offset += 2 - cost = cost_offset + min_dist ** 2 + 0.5 * self.dists_final[-1] ** 2 + 1e-4 * action_cost # + min_dist ** 2 - else: - if self.ball_cup_contact: - cost_offset += 1 - cost = cost_offset + self.dists_final[-1] ** 2 + 1e-4 * action_cost - - reward = - 1*cost - self.collision_penalty * int(self._is_collided) - success = ball_in_cup and not self.ball_ground_contact and not self.ball_wall_contact and not self.ball_cup_contact - else: - reward = - 1e-4 * action_cost - success = False - - infos = {} - infos["success"] = success - infos["is_collided"] = self._is_collided - infos["ball_pos"] = ball_pos.copy() - infos["ball_vel"] = ball_vel.copy() - infos["action_cost"] = 5e-4 * action_cost - - return reward, infos - - def _check_collision_single_objects(self, sim, id_1, id_2): - for coni in range(0, sim.data.ncon): - con = sim.data.contact[coni] - - collision = con.geom1 == id_1 and con.geom2 == id_2 - collision_trans = con.geom1 == id_2 and con.geom2 == id_1 - - if collision or collision_trans: - return True - return False - - def _check_collision_with_itself(self, sim, collision_ids): - col_1, col_2 = False, False - for j, id in enumerate(collision_ids): - col_1 = self._check_collision_with_set_of_objects(sim, id, collision_ids[:j]) - if j != len(collision_ids) - 1: - col_2 = self._check_collision_with_set_of_objects(sim, id, collision_ids[j + 1:]) - else: - col_2 = False - collision = True if col_1 or col_2 else False - return collision - - def _check_collision_with_set_of_objects(self, sim, id_1, id_list): - for coni in range(0, sim.data.ncon): - con = sim.data.contact[coni] - - collision = con.geom1 in id_list and con.geom2 == id_1 - collision_trans = con.geom1 == id_1 and con.geom2 in id_list - - if collision or collision_trans: - return True - return False \ No newline at end of file diff --git a/alr_envs/alr/mujoco/beerpong/beerpong_reward_simple.py b/alr_envs/alr/mujoco/beerpong/beerpong_reward_simple.py deleted file mode 100644 index fbe2163..0000000 --- a/alr_envs/alr/mujoco/beerpong/beerpong_reward_simple.py +++ /dev/null @@ -1,141 +0,0 @@ -import numpy as np -from alr_envs.alr.mujoco import alr_reward_fct - - -class BeerpongReward(alr_reward_fct.AlrReward): - def __init__(self, sim, sim_time): - - self.sim = sim - self.sim_time = sim_time - - self.collision_objects = ["cup_geom1", "cup_geom2", "wrist_palm_link_convex_geom", - "wrist_pitch_link_convex_decomposition_p1_geom", - "wrist_pitch_link_convex_decomposition_p2_geom", - "wrist_pitch_link_convex_decomposition_p3_geom", - "wrist_yaw_link_convex_decomposition_p1_geom", - "wrist_yaw_link_convex_decomposition_p2_geom", - "forearm_link_convex_decomposition_p1_geom", - "forearm_link_convex_decomposition_p2_geom"] - - self.ball_id = None - self.ball_collision_id = None - self.goal_id = None - self.goal_final_id = None - self.collision_ids = None - - self.ball_traj = None - self.dists = None - self.dists_ctxt = None - self.dists_final = None - self.costs = None - - self.reset(None) - - def reset(self, context): - self.ball_traj = np.zeros(shape=(self.sim_time, 3)) - self.dists = [] - self.dists_ctxt = [] - self.dists_final = [] - self.costs = [] - self.action_costs = [] - self.context = context - self.ball_in_cup = False - self.dist_ctxt = 5 - self.bounce_dist = 2 - self.min_dist = 2 - self.dist_final = 2 - self.table_contact = False - - self.ball_id = self.sim.model._body_name2id["ball"] - self.ball_collision_id = self.sim.model._geom_name2id["ball_geom"] - self.cup_robot_id = self.sim.model._site_name2id["cup_robot_final"] - self.goal_id = self.sim.model._site_name2id["cup_goal_table"] - self.goal_final_id = self.sim.model._site_name2id["cup_goal_final_table"] - self.collision_ids = [self.sim.model._geom_name2id[name] for name in self.collision_objects] - self.cup_table_id = self.sim.model._body_name2id["cup_table"] - self.bounce_table_id = self.sim.model._site_name2id["bounce_table"] - - def compute_reward(self, action, sim, step): - action_cost = np.sum(np.square(action)) - self.action_costs.append(action_cost) - - stop_sim = False - success = False - - if self.check_collision(sim): - reward = - 1e-2 * action_cost - 10 - stop_sim = True - return reward, success, stop_sim - - # Compute the current distance from the ball to the inner part of the cup - goal_pos = sim.data.site_xpos[self.goal_id] - ball_pos = sim.data.body_xpos[self.ball_id] - bounce_pos = sim.data.site_xpos[self.bounce_table_id] - goal_final_pos = sim.data.site_xpos[self.goal_final_id] - self.dists.append(np.linalg.norm(goal_pos - ball_pos)) - self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos)) - self.ball_traj[step, :] = ball_pos - - ball_in_cup = self.check_ball_in_cup(sim, self.ball_collision_id) - table_contact = self.check_ball_table_contact(sim, self.ball_collision_id) - - if table_contact and not self.table_contact: - self.bounce_dist = np.minimum((np.linalg.norm(bounce_pos - ball_pos)), 2) - self.table_contact = True - - if step == self.sim_time - 1: - min_dist = np.min(self.dists) - self.min_dist = min_dist - dist_final = self.dists_final[-1] - self.dist_final = dist_final - - cost = 0.33 * min_dist + 0.33 * dist_final + 0.33 * self.bounce_dist - reward = np.exp(-2 * cost) - 1e-2 * action_cost - success = self.bounce_dist < 0.05 and dist_final < 0.05 and ball_in_cup - else: - reward = - 1e-2 * action_cost - success = False - - return reward, success, stop_sim - - def _get_stage_wise_cost(self, ball_in_cup, min_dist, dist_final, dist_to_ctxt): - if not ball_in_cup: - cost = 3 + 2*(0.5 * min_dist**2 + 0.5 * dist_final**2) - else: - cost = 2 * dist_to_ctxt ** 2 - print('Context Distance:', dist_to_ctxt) - return cost - - def check_ball_table_contact(self, sim, ball_collision_id): - table_collision_id = sim.model._geom_name2id["table_contact_geom"] - for coni in range(0, sim.data.ncon): - con = sim.data.contact[coni] - collision = con.geom1 == table_collision_id and con.geom2 == ball_collision_id - collision_trans = con.geom1 == ball_collision_id and con.geom2 == table_collision_id - - if collision or collision_trans: - return True - return False - - def check_ball_in_cup(self, sim, ball_collision_id): - cup_base_collision_id = sim.model._geom_name2id["cup_base_table_contact"] - for coni in range(0, sim.data.ncon): - con = sim.data.contact[coni] - - collision = con.geom1 == cup_base_collision_id and con.geom2 == ball_collision_id - collision_trans = con.geom1 == ball_collision_id and con.geom2 == cup_base_collision_id - - if collision or collision_trans: - return True - return False - - def check_collision(self, sim): - for coni in range(0, sim.data.ncon): - con = sim.data.contact[coni] - - collision = con.geom1 in self.collision_ids and con.geom2 == self.ball_collision_id - collision_trans = con.geom1 == self.ball_collision_id and con.geom2 in self.collision_ids - - if collision or collision_trans: - return True - return False diff --git a/alr_envs/alr/mujoco/beerpong/beerpong_simple.py b/alr_envs/alr/mujoco/beerpong/beerpong_simple.py deleted file mode 100644 index 1708d38..0000000 --- a/alr_envs/alr/mujoco/beerpong/beerpong_simple.py +++ /dev/null @@ -1,166 +0,0 @@ -from gym import utils -import os -import numpy as np -from gym.envs.mujoco import MujocoEnv - - -class ALRBeerpongEnv(MujocoEnv, utils.EzPickle): - def __init__(self, n_substeps=4, apply_gravity_comp=True, reward_function=None): - self._steps = 0 - - self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", - "beerpong" + ".xml") - - self.start_pos = np.array([0.0, 1.35, 0.0, 1.18, 0.0, -0.786, -1.59]) - self.start_vel = np.zeros(7) - - self._q_pos = [] - self._q_vel = [] - # self.weight_matrix_scale = 50 - self.max_ctrl = np.array([150., 125., 40., 60., 5., 5., 2.]) - self.p_gains = 1 / self.max_ctrl * np.array([200, 300, 100, 100, 10, 10, 2.5]) - self.d_gains = 1 / self.max_ctrl * np.array([7, 15, 5, 2.5, 0.3, 0.3, 0.05]) - - self.j_min = np.array([-2.6, -1.985, -2.8, -0.9, -4.55, -1.5707, -2.7]) - self.j_max = np.array([2.6, 1.985, 2.8, 3.14159, 1.25, 1.5707, 2.7]) - - self.context = None - - # alr_mujoco_env.AlrMujocoEnv.__init__(self, - # self.xml_path, - # apply_gravity_comp=apply_gravity_comp, - # n_substeps=n_substeps) - - self.sim_time = 8 # seconds - # self.sim_steps = int(self.sim_time / self.dt) - if reward_function is None: - from alr_envs.alr.mujoco.beerpong.beerpong_reward_simple import BeerpongReward - reward_function = BeerpongReward - self.reward_function = reward_function(self.sim, self.sim_steps) - self.cup_robot_id = self.sim.model._site_name2id["cup_robot_final"] - self.ball_id = self.sim.model._body_name2id["ball"] - self.cup_table_id = self.sim.model._body_name2id["cup_table"] - # self.bounce_table_id = self.sim.model._body_name2id["bounce_table"] - - MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=n_substeps) - utils.EzPickle.__init__(self) - - @property - def current_pos(self): - return self.sim.data.qpos[0:7].copy() - - @property - def current_vel(self): - return self.sim.data.qvel[0:7].copy() - - def configure(self, context): - if context is None: - context = np.array([0, -2, 0.840]) - self.context = context - self.reward_function.reset(context) - - def reset_model(self): - init_pos_all = self.init_qpos.copy() - init_pos_robot = self.start_pos - init_vel = np.zeros_like(init_pos_all) - - self._steps = 0 - self._q_pos = [] - self._q_vel = [] - - start_pos = init_pos_all - start_pos[0:7] = init_pos_robot - # start_pos[7:] = np.copy(self.sim.data.site_xpos[self.cup_robot_id, :]) + np.array([0., 0.0, 0.05]) - - self.set_state(start_pos, init_vel) - - ball_pos = np.copy(self.sim.data.site_xpos[self.cup_robot_id, :]) + np.array([0., 0.0, 0.05]) - self.sim.model.body_pos[self.ball_id] = ball_pos.copy() - self.sim.model.body_pos[self.cup_table_id] = self.context.copy() - # self.sim.model.body_pos[self.bounce_table_id] = self.context.copy() - - self.sim.forward() - - return self._get_obs() - - def step(self, a): - reward_dist = 0.0 - angular_vel = 0.0 - reward_ctrl = - np.square(a).sum() - action_cost = np.sum(np.square(a)) - - crash = self.do_simulation(a, self.frame_skip) - joint_cons_viol = self.check_traj_in_joint_limits() - - self._q_pos.append(self.sim.data.qpos[0:7].ravel().copy()) - self._q_vel.append(self.sim.data.qvel[0:7].ravel().copy()) - - ob = self._get_obs() - - if not crash and not joint_cons_viol: - reward, success, stop_sim = self.reward_function.compute_reward(a, self.sim, self._steps) - done = success or self._steps == self.sim_steps - 1 or stop_sim - self._steps += 1 - else: - reward = -10 - 1e-2 * action_cost - success = False - done = True - return ob, reward, done, dict(reward_dist=reward_dist, - reward_ctrl=reward_ctrl, - velocity=angular_vel, - traj=self._q_pos, is_success=success, - is_collided=crash or joint_cons_viol) - - def check_traj_in_joint_limits(self): - return any(self.current_pos > self.j_max) or any(self.current_pos < self.j_min) - - def extend_des_pos(self, des_pos): - des_pos_full = self.start_pos.copy() - des_pos_full[1] = des_pos[0] - des_pos_full[3] = des_pos[1] - des_pos_full[5] = des_pos[2] - return des_pos_full - - def extend_des_vel(self, des_vel): - des_vel_full = self.start_vel.copy() - des_vel_full[1] = des_vel[0] - des_vel_full[3] = des_vel[1] - des_vel_full[5] = des_vel[2] - return des_vel_full - - def _get_obs(self): - theta = self.sim.data.qpos.flat[:7] - return np.concatenate([ - np.cos(theta), - np.sin(theta), - # self.get_body_com("target"), # only return target to make problem harder - [self._steps], - ]) - - - -if __name__ == "__main__": - env = ALRBeerpongEnv() - ctxt = np.array([0, -2, 0.840]) # initial - - env.configure(ctxt) - env.reset() - env.render() - for i in range(16000): - # test with random actions - ac = 0.0 * env.action_space.sample()[0:7] - ac[1] = -0.01 - ac[3] = - 0.01 - ac[5] = -0.01 - # ac = env.start_pos - # ac[0] += np.pi/2 - obs, rew, d, info = env.step(ac) - env.render() - - print(rew) - - if d: - break - - env.close() - diff --git a/alr_envs/alr/mujoco/beerpong/mp_wrapper.py b/alr_envs/alr/mujoco/beerpong/mp_wrapper.py index 022490c..40c371b 100644 --- a/alr_envs/alr/mujoco/beerpong/mp_wrapper.py +++ b/alr_envs/alr/mujoco/beerpong/mp_wrapper.py @@ -2,13 +2,13 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: return np.hstack([ [False] * 7, # cos [False] * 7, # sin diff --git a/alr_envs/alr/mujoco/half_cheetah_jump/mp_wrapper.py b/alr_envs/alr/mujoco/half_cheetah_jump/mp_wrapper.py index 6179b07..f9a298a 100644 --- a/alr_envs/alr/mujoco/half_cheetah_jump/mp_wrapper.py +++ b/alr_envs/alr/mujoco/half_cheetah_jump/mp_wrapper.py @@ -2,12 +2,12 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: return np.hstack([ [False] * 17, [True] # goal height diff --git a/alr_envs/alr/mujoco/hopper_jump/hopper_jump.py b/alr_envs/alr/mujoco/hopper_jump/hopper_jump.py index 5cd234c..025bb8d 100644 --- a/alr_envs/alr/mujoco/hopper_jump/hopper_jump.py +++ b/alr_envs/alr/mujoco/hopper_jump/hopper_jump.py @@ -54,7 +54,8 @@ class ALRHopperJumpEnv(HopperEnv): self.current_step += 1 self.do_simulation(action, self.frame_skip) height_after = self.get_body_com("torso")[2] - site_pos_after = self.sim.data.site_xpos[self.model.site_name2id('foot_site')].copy() + # site_pos_after = self.sim.data.site_xpos[self.model.site_name2id('foot_site')].copy() + site_pos_after = self.get_body_com('foot_site') self.max_height = max(height_after, self.max_height) ctrl_cost = self.control_cost(action) diff --git a/alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py b/alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py index 36b7158..e3279aa 100644 --- a/alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py +++ b/alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py @@ -2,12 +2,12 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: return np.hstack([ [False] * (5 + int(not self.exclude_current_positions_from_observation)), # position [False] * 6, # velocity diff --git a/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py b/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py index 909e00a..f5bf08d 100644 --- a/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py +++ b/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py @@ -2,12 +2,12 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: return np.hstack([ [False] * 17, [True] # goal pos diff --git a/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py b/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py index 049c2f0..a8cd696 100644 --- a/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py @@ -2,8 +2,6 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper - from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper diff --git a/alr_envs/alr/mujoco/reacher/__init__.py b/alr_envs/alr/mujoco/reacher/__init__.py index 989b5a9..c1a25d3 100644 --- a/alr_envs/alr/mujoco/reacher/__init__.py +++ b/alr_envs/alr/mujoco/reacher/__init__.py @@ -1 +1,2 @@ -from .mp_wrapper import MPWrapper \ No newline at end of file +from .mp_wrapper import MPWrapper +from .new_mp_wrapper import MPWrapper as NewMPWrapper \ No newline at end of file diff --git a/alr_envs/alr/mujoco/reacher/mp_wrapper.py b/alr_envs/alr/mujoco/reacher/mp_wrapper.py index 3b655d4..e51843c 100644 --- a/alr_envs/alr/mujoco/reacher/mp_wrapper.py +++ b/alr_envs/alr/mujoco/reacher/mp_wrapper.py @@ -1,13 +1,14 @@ from typing import Union import numpy as np -from mp_env_api import MPEnvWrapper + +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: return np.concatenate([ [False] * self.n_links, # cos [False] * self.n_links, # sin diff --git a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py index 54910e5..6b50d80 100644 --- a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py @@ -8,12 +8,6 @@ from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): @property - def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.sim.data.qpos.flat[:self.env.n_links] - @property - def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.sim.data.qvel.flat[:self.env.n_links] - def context_mask(self): return np.concatenate([ [False] * self.env.n_links, # cos @@ -24,3 +18,11 @@ class MPWrapper(RawInterfaceWrapper): # self.get_body_com("target"), # only return target to make problem harder [False], # step ]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.sim.data.qpos.flat[:self.env.n_links] + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.sim.data.qvel.flat[:self.env.n_links] diff --git a/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py b/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py index 473583f..408124a 100644 --- a/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py +++ b/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py @@ -2,13 +2,13 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # TODO: @Max Filter observations correctly return np.hstack([ [False] * 7, # Joint Pos diff --git a/alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py b/alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py index 445fa40..0c2dba5 100644 --- a/alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py +++ b/alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py @@ -2,12 +2,12 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: return np.hstack([ [False] * 17, [True] # goal pos diff --git a/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py index dde928f..96b0739 100644 --- a/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py @@ -2,8 +2,6 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper - from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper diff --git a/alr_envs/dmc/manipulation/reach_site/mp_wrapper.py b/alr_envs/dmc/manipulation/reach_site/mp_wrapper.py index 2d03f7b..6d5029e 100644 --- a/alr_envs/dmc/manipulation/reach_site/mp_wrapper.py +++ b/alr_envs/dmc/manipulation/reach_site/mp_wrapper.py @@ -2,13 +2,13 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # Joint and target positions are randomized, velocities are always set to 0. return np.hstack([ [True] * 3, # target position diff --git a/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py b/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py index fb068b3..9687bed 100644 --- a/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py +++ b/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py @@ -2,13 +2,13 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # Besides the ball position, the environment is always set to 0. return np.hstack([ [False] * 2, # cup position diff --git a/alr_envs/dmc/suite/cartpole/mp_wrapper.py b/alr_envs/dmc/suite/cartpole/mp_wrapper.py index 1ca99f5..3f16d24 100644 --- a/alr_envs/dmc/suite/cartpole/mp_wrapper.py +++ b/alr_envs/dmc/suite/cartpole/mp_wrapper.py @@ -2,18 +2,17 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): def __init__(self, env, n_poles: int = 1): self.n_poles = n_poles super().__init__(env) - @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # Besides the ball position, the environment is always set to 0. return np.hstack([ [True], # slider position diff --git a/alr_envs/dmc/suite/reacher/mp_wrapper.py b/alr_envs/dmc/suite/reacher/mp_wrapper.py index 86bc992..ac857c1 100644 --- a/alr_envs/dmc/suite/reacher/mp_wrapper.py +++ b/alr_envs/dmc/suite/reacher/mp_wrapper.py @@ -2,13 +2,13 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # Joint and target positions are randomized, velocities are always set to 0. return np.hstack([ [True] * 2, # joint position diff --git a/alr_envs/examples/examples_dmc.py b/alr_envs/examples/examples_dmc.py index 5658b1f..41d2231 100644 --- a/alr_envs/examples/examples_dmc.py +++ b/alr_envs/examples/examples_dmc.py @@ -59,7 +59,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): # Base DMC name, according to structure of above example base_env = "ball_in_cup-catch" - # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper. + # Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper. # You can also add other gym.Wrappers in case they are needed. wrappers = [alr_envs.dmc.suite.ball_in_cup.MPWrapper] mp_kwargs = { diff --git a/alr_envs/examples/examples_metaworld.py b/alr_envs/examples/examples_metaworld.py index 3e040cc..f179149 100644 --- a/alr_envs/examples/examples_metaworld.py +++ b/alr_envs/examples/examples_metaworld.py @@ -62,7 +62,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): # Base MetaWorld name, according to structure of above example base_env = "button-press-v2" - # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper. + # Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper. # You can also add other gym.Wrappers in case they are needed. wrappers = [alr_envs.meta.goal_and_object_change.MPWrapper] mp_kwargs = { diff --git a/alr_envs/examples/examples_motion_primitives.py b/alr_envs/examples/examples_movement_primitives.py similarity index 82% rename from alr_envs/examples/examples_motion_primitives.py rename to alr_envs/examples/examples_movement_primitives.py index b9d355a..bf1f950 100644 --- a/alr_envs/examples/examples_motion_primitives.py +++ b/alr_envs/examples/examples_movement_primitives.py @@ -59,6 +59,17 @@ def example_custom_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations= """ # Changing the traj_gen_kwargs is possible by providing them to gym. # E.g. here by providing way to many basis functions + # mp_dict = alr_envs.from_default_config('ALRReacher-v0', {'basis_generator_kwargs': {'num_basis': 10}}) + # mp_dict.update({'basis_generator_kwargs': {'num_basis': 10}}) + # mp_dict.update({'black_box_kwargs': {'learn_sub_trajectories': True}}) + # mp_dict.update({'black_box_kwargs': {'do_replanning': lambda pos, vel, t: lambda t: t % 100}}) + + # default env with promp and no learn_sub_trajectories and replanning + # env = alr_envs.make('ALRReacherProMP-v0', 1, n_links=7) + env = alr_envs.make('ALRReacherProMP-v0', 1, basis_generator_kwargs={'num_basis': 10}, n_links=7) + # env = alr_envs.make('ALRReacher-v0', seed=1, bb_kwargs=mp_dict, n_links=1) + # env = alr_envs.make_bb('ALRReacher-v0', **mp_dict) + mp_kwargs = { "num_dof": 5, "num_basis": 1000, @@ -110,7 +121,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): base_env = "alr_envs:HoleReacher-v1" - # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper. + # Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper. # You can also add other gym.Wrappers in case they are needed. wrappers = [alr_envs.alr.classic_control.hole_reacher.MPWrapper] mp_kwargs = { @@ -148,14 +159,14 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): if __name__ == '__main__': render = False - # DMP - example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render) - - # ProMP - example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render) - - # DetProMP - example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render) + # # DMP + # example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render) + # + # # ProMP + # example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render) + # + # # DetProMP + # example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render) # Altered basis functions example_custom_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render) diff --git a/alr_envs/examples/examples_open_ai.py b/alr_envs/examples/examples_open_ai.py index 631a3a1..46dcf60 100644 --- a/alr_envs/examples/examples_open_ai.py +++ b/alr_envs/examples/examples_open_ai.py @@ -4,7 +4,7 @@ import alr_envs def example_mp(env_name, seed=1): """ Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered. - For more information on motion primitive specific stuff, look at the trajectory_generator examples. + For more information on motion primitive specific stuff, look at the traj_gen examples. Args: env_name: ProMP env_id seed: seed diff --git a/alr_envs/examples/pd_control_gain_tuning.py b/alr_envs/examples/pd_control_gain_tuning.py index e05bad1..27cf8f8 100644 --- a/alr_envs/examples/pd_control_gain_tuning.py +++ b/alr_envs/examples/pd_control_gain_tuning.py @@ -8,7 +8,7 @@ from alr_envs.utils.make_env_helpers import make_promp_env def visualize(env): t = env.t - pos_features = env.trajectory_generator.basis_generator.basis(t) + pos_features = env.traj_gen.basis_generator.basis(t) plt.plot(t, pos_features) plt.show() diff --git a/alr_envs/meta/goal_change_mp_wrapper.py b/alr_envs/meta/goal_change_mp_wrapper.py index a558365..17495da 100644 --- a/alr_envs/meta/goal_change_mp_wrapper.py +++ b/alr_envs/meta/goal_change_mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): """ This Wrapper is for environments where merely the goal changes in the beginning and no secondary objects or end effectors are altered at the start of an episode. @@ -27,7 +27,7 @@ class MPWrapper(MPEnvWrapper): """ @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # This structure is the same for all metaworld environments. # Only the observations which change could differ return np.hstack([ diff --git a/alr_envs/meta/goal_endeffector_change_mp_wrapper.py b/alr_envs/meta/goal_endeffector_change_mp_wrapper.py index 8912a72..3a6ad1c 100644 --- a/alr_envs/meta/goal_endeffector_change_mp_wrapper.py +++ b/alr_envs/meta/goal_endeffector_change_mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): """ This Wrapper is for environments where merely the goal changes in the beginning and no secondary objects or end effectors are altered at the start of an episode. @@ -27,7 +27,7 @@ class MPWrapper(MPEnvWrapper): """ @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # This structure is the same for all metaworld environments. # Only the observations which change could differ return np.hstack([ diff --git a/alr_envs/meta/goal_object_change_mp_wrapper.py b/alr_envs/meta/goal_object_change_mp_wrapper.py index 63e16b7..97c64b8 100644 --- a/alr_envs/meta/goal_object_change_mp_wrapper.py +++ b/alr_envs/meta/goal_object_change_mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): """ This Wrapper is for environments where merely the goal changes in the beginning and no secondary objects or end effectors are altered at the start of an episode. @@ -27,7 +27,7 @@ class MPWrapper(MPEnvWrapper): """ @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # This structure is the same for all metaworld environments. # Only the observations which change could differ return np.hstack([ diff --git a/alr_envs/meta/object_change_mp_wrapper.py b/alr_envs/meta/object_change_mp_wrapper.py index 4293148..f832c9f 100644 --- a/alr_envs/meta/object_change_mp_wrapper.py +++ b/alr_envs/meta/object_change_mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): """ This Wrapper is for environments where merely the goal changes in the beginning and no secondary objects or end effectors are altered at the start of an episode. @@ -27,7 +27,7 @@ class MPWrapper(MPEnvWrapper): """ @property - def active_obs(self): + def context_mask(self) -> np.ndarray: # This structure is the same for all metaworld environments. # Only the observations which change could differ return np.hstack([ diff --git a/alr_envs/mp/black_box_wrapper.py b/alr_envs/mp/black_box_wrapper.py index 9e6a9e5..0c2a7c8 100644 --- a/alr_envs/mp/black_box_wrapper.py +++ b/alr_envs/mp/black_box_wrapper.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Tuple +from typing import Tuple, Union import gym import numpy as np @@ -16,7 +16,9 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC): def __init__(self, env: RawInterfaceWrapper, trajectory_generator: MPInterface, tracking_controller: BaseController, - duration: float, verbose: int = 1, sequencing: bool = True, reward_aggregation: callable = np.sum): + duration: float, verbose: int = 1, learn_sub_trajectories: bool = False, + replanning_schedule: Union[None, callable] = None, + reward_aggregation: callable = np.sum): """ gym.Wrapper for leveraging a black box approach with a trajectory generator. @@ -26,6 +28,9 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC): tracking_controller: Translates the desired trajectory to raw action sequences duration: Length of the trajectory of the movement primitive in seconds verbose: level of detail for returned values in info dict. + learn_sub_trajectories: Transforms full episode learning into learning sub-trajectories, similar to + step-based learning + replanning_schedule: callable that receives reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory reward, default summation over all values. """ @@ -33,21 +38,22 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC): self.env = env self.duration = duration - self.sequencing = sequencing + self.learn_sub_trajectories = learn_sub_trajectories + self.replanning_schedule = replanning_schedule self.current_traj_steps = 0 # trajectory generation - self.trajectory_generator = trajectory_generator + self.traj_gen = trajectory_generator self.tracking_controller = tracking_controller # self.time_steps = np.linspace(0, self.duration, self.traj_steps) - # self.trajectory_generator.set_mp_times(self.time_steps) - self.trajectory_generator.set_duration(np.array([self.duration]), np.array([self.dt])) + # self.traj_gen.set_mp_times(self.time_steps) + self.traj_gen.set_duration(np.array([self.duration]), np.array([self.dt])) # reward computation self.reward_aggregation = reward_aggregation # spaces - self.return_context_observation = not (self.sequencing) # TODO or we_do_replanning?) + self.return_context_observation = not (self.learn_sub_trajectories or replanning_schedule) self.traj_gen_action_space = self.get_traj_gen_action_space() self.action_space = self.get_action_space() self.observation_space = spaces.Box(low=self.env.observation_space.low[self.env.context_mask], @@ -60,26 +66,26 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC): def observation(self, observation): # return context space if we are - return observation[self.context_mask] if self.return_context_observation else observation + return observation[self.env.context_mask] if self.return_context_observation else observation def get_trajectory(self, action: np.ndarray) -> Tuple: clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high) - self.trajectory_generator.set_params(clipped_params) - # if self.trajectory_generator.learn_tau: - # self.trajectory_generator.set_mp_duration(self.trajectory_generator.tau, np.array([self.dt])) - # TODO: Bruce said DMP, ProMP, ProDMP can have 0 bc_time - self.trajectory_generator.set_boundary_conditions(bc_time=np.zeros((1,)), bc_pos=self.current_pos, - bc_vel=self.current_vel) + self.traj_gen.set_params(clipped_params) + # TODO: Bruce said DMP, ProMP, ProDMP can have 0 bc_time for sequencing + # TODO Check with Bruce for replanning + self.traj_gen.set_boundary_conditions( + bc_time=np.zeros((1,)) if not self.replanning_schedule else self.current_traj_steps * self.dt, + bc_pos=self.current_pos, bc_vel=self.current_vel) # TODO: is this correct for replanning? Do we need to adjust anything here? - self.trajectory_generator.set_duration(None if self.sequencing else self.duration, np.array([self.dt])) - traj_dict = self.trajectory_generator.get_trajs(get_pos=True, get_vel=True) + self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, np.array([self.dt])) + traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel'] return get_numpy(trajectory_tensor), get_numpy(velocity_tensor) def get_traj_gen_action_space(self): - """This function can be used to set up an individual space for the parameters of the trajectory_generator.""" - min_action_bounds, max_action_bounds = self.trajectory_generator.get_param_bounds() + """This function can be used to set up an individual space for the parameters of the traj_gen.""" + min_action_bounds, max_action_bounds = self.traj_gen.get_param_bounds() mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(), dtype=np.float32) return mp_action_space @@ -134,8 +140,11 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC): if self.render_kwargs: self.render(**self.render_kwargs) - if done or self.env.do_replanning(self.current_pos, self.current_vel, obs, c_action, - t + 1 + self.current_traj_steps): + if done: + break + + if self.replanning_schedule and self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, + t + 1 + self.current_traj_steps): break infos.update({k: v[:t + 1] for k, v in infos.items()}) @@ -160,6 +169,7 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC): def reset(self, **kwargs): self.current_traj_steps = 0 + super(BlackBoxWrapper, self).reset(**kwargs) def plot_trajs(self, des_trajs, des_vels): import matplotlib.pyplot as plt diff --git a/alr_envs/mp/raw_interface_wrapper.py b/alr_envs/mp/raw_interface_wrapper.py index d57ff9a..fdbc2f7 100644 --- a/alr_envs/mp/raw_interface_wrapper.py +++ b/alr_envs/mp/raw_interface_wrapper.py @@ -62,8 +62,8 @@ class RawInterfaceWrapper(gym.Wrapper): include other actions like ball releasing time for the beer pong environment. This only needs to be overwritten if the action space is modified. Args: - action: a vector instance of the whole action space, includes trajectory_generator parameters and additional parameters if - specified, else only trajectory_generator parameters + action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if + specified, else only traj_gen parameters Returns: Tuple: mp_arguments and other arguments diff --git a/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py b/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py index 2a2357a..189563c 100644 --- a/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py +++ b/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py @@ -1,10 +1,11 @@ from typing import Union import numpy as np -from mp_env_api import MPEnvWrapper + +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property def current_vel(self) -> Union[float, int, np.ndarray]: return np.array([self.state[1]]) diff --git a/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py b/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py index 16202e5..9d627b6 100644 --- a/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py +++ b/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py @@ -1,10 +1,11 @@ from typing import Union import numpy as np -from mp_env_api import MPEnvWrapper + +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property def current_vel(self) -> Union[float, int, np.ndarray]: @@ -13,7 +14,3 @@ class MPWrapper(MPEnvWrapper): @property def current_pos(self) -> Union[float, int, np.ndarray]: return self.sim.data.qpos[:2] - - @property - def dt(self) -> Union[float, int]: - return self.env.dt \ No newline at end of file diff --git a/alr_envs/open_ai/robotics/fetch/mp_wrapper.py b/alr_envs/open_ai/robotics/fetch/mp_wrapper.py index 218e175..7a7bed6 100644 --- a/alr_envs/open_ai/robotics/fetch/mp_wrapper.py +++ b/alr_envs/open_ai/robotics/fetch/mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Union import numpy as np -from mp_env_api import MPEnvWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(MPEnvWrapper): +class MPWrapper(RawInterfaceWrapper): @property def active_obs(self): diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py index 9af0a2d..b5587a7 100644 --- a/alr_envs/utils/make_env_helpers.py +++ b/alr_envs/utils/make_env_helpers.py @@ -1,18 +1,19 @@ import warnings -from typing import Iterable, Type, Union, Mapping, MutableMapping +from copy import deepcopy +from typing import Iterable, Type, Union, MutableMapping import gym import numpy as np -from gym.envs.registration import EnvSpec -from mp_pytorch import MPInterface +from gym.envs.registration import EnvSpec, registry +from gym.wrappers import TimeAwareObservation from alr_envs.mp.basis_generator_factory import get_basis_generator from alr_envs.mp.black_box_wrapper import BlackBoxWrapper -from alr_envs.mp.controllers.base_controller import BaseController from alr_envs.mp.controllers.controller_factory import get_controller from alr_envs.mp.mp_factory import get_trajectory_generator from alr_envs.mp.phase_generator_factory import get_phase_generator from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.utils.utils import nested_update def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs): @@ -41,7 +42,15 @@ def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwa return f if return_callable else f() -def make(env_id: str, seed, **kwargs): +def make(env_id, seed, **kwargs): + spec = registry.get(env_id) + # This access is required to allow for nested dict updates + all_kwargs = deepcopy(spec._kwargs) + nested_update(all_kwargs, **kwargs) + return _make(env_id, seed, **all_kwargs) + + +def _make(env_id: str, seed, **kwargs): """ Converts an env_id to an environment with the gym API. This also works for DeepMind Control Suite interface_wrappers @@ -102,12 +111,12 @@ def _make_wrapped_env( ): """ Helper function for creating a wrapped gym environment using MPs. - It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is + It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is provided to expose the interface for MPs. Args: env_id: name of the environment - wrappers: list of wrappers (at least an MPEnvWrapper), + wrappers: list of wrappers (at least an RawInterfaceWrapper), seed: seed of environment Returns: gym environment with all specified wrappers applied @@ -126,22 +135,20 @@ def _make_wrapped_env( return _env -def make_bb_env( - env_id: str, wrappers: Iterable, black_box_wrapper_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping, +def make_bb( + env_id: str, wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping, controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1, - sequenced=False, **kwargs): + **kwargs): """ This can also be used standalone for manually building a custom DMP environment. Args: - black_box_wrapper_kwargs: kwargs for the black-box wrapper + black_box_kwargs: kwargs for the black-box wrapper basis_kwargs: kwargs for the basis generator phase_kwargs: kwargs for the phase generator controller_kwargs: kwargs for the tracking controller env_id: base_env_name, wrappers: list of wrappers (at least an BlackBoxWrapper), seed: seed of environment - sequenced: When true, this allows to sequence multiple ProMPs by specifying the duration of each sub-trajectory, - this behavior is much closer to step based learning. traj_gen_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP Returns: DMP wrapped gym env @@ -150,19 +157,33 @@ def make_bb_env( _verify_time_limit(traj_gen_kwargs.get("duration", None), kwargs.get("time_limit", None)) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) - if black_box_wrapper_kwargs.get('duration', None) is None: - black_box_wrapper_kwargs['duration'] = _env.spec.max_episode_steps * _env.dt - if phase_kwargs.get('tau', None) is None: - phase_kwargs['tau'] = black_box_wrapper_kwargs['duration'] + learn_sub_trajs = black_box_kwargs.get('learn_sub_trajectories') + do_replanning = black_box_kwargs.get('replanning_schedule') + if learn_sub_trajs and do_replanning: + raise ValueError('Cannot used sub-trajectory learning and replanning together.') + + if learn_sub_trajs or do_replanning: + # add time_step observation when replanning + kwargs['wrappers'].append(TimeAwareObservation) + traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(_env.action_space.shape).item()) + if black_box_kwargs.get('duration') is None: + black_box_kwargs['duration'] = _env.spec.max_episode_steps * _env.dt + if phase_kwargs.get('tau') is None: + phase_kwargs['tau'] = black_box_kwargs['duration'] + + if learn_sub_trajs is not None: + # We have to learn the length when learning sub_trajectories trajectories + phase_kwargs['learn_tau'] = True + phase_gen = get_phase_generator(**phase_kwargs) basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs) controller = get_controller(**controller_kwargs) traj_gen = get_trajectory_generator(basis_generator=basis_gen, **traj_gen_kwargs) bb_env = BlackBoxWrapper(_env, trajectory_generator=traj_gen, tracking_controller=controller, - **black_box_wrapper_kwargs) + **black_box_kwargs) return bb_env @@ -204,16 +225,16 @@ def make_bb_env_helper(**kwargs): wrappers = kwargs.pop("wrappers") traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {}) - black_box_kwargs = kwargs.pop('black_box_wrapper_kwargs', {}) + black_box_kwargs = kwargs.pop('black_box_kwargs', {}) contr_kwargs = kwargs.pop("controller_kwargs", {}) phase_kwargs = kwargs.pop("phase_generator_kwargs", {}) basis_kwargs = kwargs.pop("basis_generator_kwargs", {}) - return make_bb_env(env_id=kwargs.pop("name"), wrappers=wrappers, - black_box_wrapper_kwargs=black_box_kwargs, - traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs, - phase_kwargs=phase_kwargs, - basis_kwargs=basis_kwargs, **kwargs, seed=seed) + return make_bb(env_id=kwargs.pop("name"), wrappers=wrappers, + black_box_kwargs=black_box_kwargs, + traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs, + phase_kwargs=phase_kwargs, + basis_kwargs=basis_kwargs, **kwargs, seed=seed) def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]): @@ -224,7 +245,7 @@ def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[ It can be found in the BaseMP class. Args: - mp_time_limit: max trajectory length of trajectory_generator in seconds + mp_time_limit: max trajectory length of traj_gen in seconds env_time_limit: max trajectory length of DMC environment in seconds Returns: diff --git a/alr_envs/utils/utils.py b/alr_envs/utils/utils.py index b90cf60..b212aac 100644 --- a/alr_envs/utils/utils.py +++ b/alr_envs/utils/utils.py @@ -1,3 +1,5 @@ +from collections import Mapping, MutableMapping + import numpy as np import torch as ch @@ -23,4 +25,24 @@ def angle_normalize(x, type="deg"): def get_numpy(x: ch.Tensor): + """ + Returns numpy array from torch tensor + Args: + x: + + Returns: + + """ return x.detach().cpu().numpy() + + +def nested_update(base: MutableMapping, update): + """ + Updated method for nested Mappings + Args: + base: main Mapping to be updated + update: updated values for base Mapping + + """ + for k, v in update.items(): + base[k] = nested_update(base.get(k, {}), v) if isinstance(v, Mapping) else v