diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py
index 4c90512..86491b5 100644
--- a/alr_envs/alr/__init__.py
+++ b/alr_envs/alr/__init__.py
@@ -24,13 +24,17 @@ ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
 DEFAULT_MP_ENV_DICT = {
     "name": 'EnvName',
     "wrappers": [],
+    # TODO move scale to traj_gen
     "ep_wrapper_kwargs": {
         "weight_scale": 1
     },
+    # TODO traj_gen_kwargs
+    # TODO remove action_dim
     "movement_primitives_kwargs": {
         'movement_primitives_type': 'promp',
         'action_dim': 7
     },
+    # TODO remove tau
     "phase_generator_kwargs": {
         'phase_generator_type': 'linear',
         'delay': 0,
@@ -40,13 +44,13 @@ DEFAULT_MP_ENV_DICT = {
     },
     "controller_kwargs": {
         'controller_type': 'motor',
-        "p_gains": np.ones(7),
-        "d_gains": np.ones(7) * 0.1,
+        "p_gains": 1.0,
+        "d_gains": 0.1,
     },
     "basis_generator_kwargs": {
         'basis_generator_type': 'zero_rbf',
         'num_basis': 5,
-        'num_basis_zero_start': 2
+        'num_basis_zero_start': 2  # TODO: Change to 1
     }
 }
diff --git a/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py
index f098c2d..d14c9a9 100644
--- a/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py
+++ b/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py
@@ -2,8 +2,6 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api import MPEnvWrapper
-
 from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
 
 
diff --git a/alr_envs/mp/black_box_wrapper.py b/alr_envs/mp/black_box_wrapper.py
index 5ae0ff9..d87c332 100644
--- a/alr_envs/mp/black_box_wrapper.py
+++ b/alr_envs/mp/black_box_wrapper.py
@@ -32,18 +32,24 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
 
         self.env = env
         self.duration = duration
-        self.traj_steps = int(duration / self.dt)
-        self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
+        self.sequencing = sequencing
+        # self.traj_steps = int(duration / self.dt)
+        # self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
         # duration = self.env.max_episode_steps * self.dt
 
         # trajectory generation
         self.trajectory_generator = trajectory_generator
         self.tracking_controller = tracking_controller
         # self.weight_scale = weight_scale
-        self.time_steps = np.linspace(0, self.duration, self.traj_steps)
-        self.trajectory_generator.set_mp_times(self.time_steps)
-        # self.trajectory_generator.set_mp_duration(self.time_steps, dt)
-        # action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params)))
+        # self.time_steps = np.linspace(0, self.duration, self.traj_steps)
+        # self.trajectory_generator.set_mp_times(self.time_steps)
+        if not sequencing:
+            self.trajectory_generator.set_mp_duration(np.array([self.duration]), np.array([self.dt]))
+        else:
+            # sequencing stuff
+            pass
+
+        # reward computation
         self.reward_aggregation = reward_aggregation
 
         # spaces
@@ -67,15 +73,12 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
         return observation[self.env.context_mask]
 
     def get_trajectory(self, action: np.ndarray) -> Tuple:
-        # TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at
-        # the beginning of the array.
-        # ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay)
-        # scaled_mp_params = action.copy()
-        # scaled_mp_params[ignore_indices:] *= self.weight_scale
-
         clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high)
         self.trajectory_generator.set_params(clipped_params)
-        self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos,
+        # if self.trajectory_generator.learn_tau:
+        #     self.trajectory_generator.set_mp_duration(self.trajectory_generator.tau, np.array([self.dt]))
+        self.trajectory_generator.set_mp_duration(None if self.sequencing else self.duration, np.array([self.dt]))
+        self.trajectory_generator.set_boundary_conditions(bc_time=np.zeros(1), bc_pos=self.current_pos,
                                                           bc_vel=self.current_vel)
         traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True)
         trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
@@ -152,7 +155,7 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
             if self.render_mode is not None:
                 self.render(mode=self.render_mode, **self.render_kwargs)
 
-            if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t):
+            if done or self.env.do_replanning(self.current_pos, self.current_vel, obs, c_action, t + past_steps):
                 break
 
         infos.update({k: v[:t + 1] for k, v in infos.items()})
diff --git a/alr_envs/mp/raw_interface_wrapper.py b/alr_envs/mp/raw_interface_wrapper.py
index 45d5daf..d57ff9a 100644
--- a/alr_envs/mp/raw_interface_wrapper.py
+++ b/alr_envs/mp/raw_interface_wrapper.py
@@ -43,13 +43,13 @@ class RawInterfaceWrapper(gym.Wrapper):
         raise NotImplementedError()
 
     @property
-    @abstractmethod
     def dt(self) -> float:
         """
         Control frequency of the environment
        Returns: float
         """
+        return self.env.dt
 
     def do_replanning(self, pos, vel, s, a, t):
         # return t % 100 == 0
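
For orientation (not part of the patch): a minimal sketch of the call sequence the new constructor and get_trajectory code imply, using only the trajectory-generator methods that appear in this diff (set_mp_duration, set_params, set_boundary_conditions, get_mp_trajs). The helper name and its arguments are illustrative, assuming an mp_pytorch-style generator object.

import numpy as np

def generate_trajectory(traj_gen, action, action_space, duration, dt,
                        current_pos, current_vel, sequencing=False):
    # Configure the generator with the rollout duration and control dt; when
    # sequencing sub-trajectories, the duration is left open (None), mirroring
    # the BlackBoxWrapper changes above.
    traj_gen.set_mp_duration(None if sequencing else np.array([duration]), np.array([dt]))
    # Clip the raw MP parameters to the parameter space before handing them over.
    clipped = np.clip(action, action_space.low, action_space.high)
    traj_gen.set_params(clipped)
    # Boundary conditions anchor the trajectory at the current state; t = 0
    # marks the start of the (sub-)trajectory.
    traj_gen.set_boundary_conditions(bc_time=np.zeros(1), bc_pos=current_pos,
                                     bc_vel=current_vel)
    trajs = traj_gen.get_mp_trajs(get_pos=True, get_vel=True)
    return trajs['pos'], trajs['vel']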