diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py
index b92c653..fd5032b 100644
--- a/fancy_gym/black_box/black_box_wrapper.py
+++ b/fancy_gym/black_box/black_box_wrapper.py
@@ -62,7 +62,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
         self.traj_gen_action_space = self._get_traj_gen_action_space()
         self.action_space = self._get_action_space()
-        self.observation_space = self._get_observation_space()
 
         # rendering
@@ -95,23 +94,16 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
         self.traj_gen.set_params(clipped_params)
         init_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
-        # TODO we could think about initializing with the previous desired value in order to have a smooth transition
-        #  at least from the planning point of view.
         condition_pos = self.condition_pos if self.condition_pos is not None else self.current_pos
         condition_vel = self.condition_vel if self.condition_vel is not None else self.current_vel
 
         self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
         self.traj_gen.set_duration(duration, self.dt)
-        # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
+
         position = get_numpy(self.traj_gen.get_traj_pos())
         velocity = get_numpy(self.traj_gen.get_traj_vel())
 
-        # if self.do_replanning:
-        #     # Remove first part of trajectory as this is already over
-        #     position = position[self.current_traj_steps:]
-        #     velocity = velocity[self.current_traj_steps:]
-
         return position, velocity
 
     def _get_traj_gen_action_space(self):
@@ -182,12 +174,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                                                       t + 1 + self.current_traj_steps)
                         and self.plan_steps < self.max_planning_times):
-                    self.condition_pos = pos if self.condition_on_desired else None
-                    self.condition_vel = vel if self.condition_on_desired else None
+                    if self.condition_on_desired:
+                        self.condition_pos = pos
+                        self.condition_vel = vel
 
                     break
 
-        infos.update({k: v[:t+1] for k, v in infos.items()})
+        infos.update({k: v[:t + 1] for k, v in infos.items()})
         self.current_traj_steps += t + 1
 
         if self.verbose >= 2:
@@ -210,6 +203,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.current_traj_steps = 0
         self.plan_steps = 0
         self.traj_gen.reset()
-        self.condition_vel = None
         self.condition_pos = None
+        self.condition_vel = None
         return super(BlackBoxWrapper, self).reset()
diff --git a/fancy_gym/meta/__init__.py b/fancy_gym/meta/__init__.py
index b9f0dca..63b15c2 100644
--- a/fancy_gym/meta/__init__.py
+++ b/fancy_gym/meta/__init__.py
@@ -13,7 +13,8 @@ DEFAULT_BB_DICT_ProMP = {
     "name": 'EnvName',
     "wrappers": [],
     "trajectory_generator_kwargs": {
-        'trajectory_generator_type': 'promp'
+        'trajectory_generator_type': 'promp',
+        'weights_scale': 10,
     },
     "phase_generator_kwargs": {
         'phase_generator_type': 'linear'
@@ -25,6 +26,9 @@
         'basis_generator_type': 'zero_rbf',
         'num_basis': 5,
         'num_basis_zero_start': 1
+    },
+    'black_box_kwargs': {
+        'condition_on_desired': False,
     }
 }
 
@@ -32,22 +36,28 @@ DEFAULT_BB_DICT_ProDMP = {
     "name": 'EnvName',
     "wrappers": [],
     "trajectory_generator_kwargs": {
-        'trajectory_generator_type': 'prodmp'
+        'trajectory_generator_type': 'prodmp',
+        'auto_scale_basis': True,
+        'weights_scale': 10,
+        # 'goal_scale': 0.,
+        'disable_goal': True,
     },
     "phase_generator_kwargs": {
-        'phase_generator_type': 'exp'
+        'phase_generator_type': 'exp',
+        # 'alpha_phase' : 3,
     },
     "controller_kwargs": {
         'controller_type': 'metaworld',
     },
     "basis_generator_kwargs": {
         'basis_generator_type': 'prodmp',
-        'num_basis': 5
+        'num_basis': 3,
+        'alpha': 10
     },
-    "black_box_kwargs": {
-        'replanning_schedule': None,
-        'max_planning_times': None,
+    'black_box_kwargs': {
+        'condition_on_desired': False,
     }
+
 }
 
 _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
@@ -152,7 +162,6 @@ for _task in _goal_and_object_change_envs:
     )
     ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
 
-
 _goal_and_endeffector_change_envs = ["basketball-v2"]
 for _task in _goal_and_endeffector_change_envs:
     task_id_split = _task.split("-")
diff --git a/fancy_gym/meta/base_metaworld_mp_wrapper.py b/fancy_gym/meta/base_metaworld_mp_wrapper.py
index 4029e28..0f1a9a9 100644
--- a/fancy_gym/meta/base_metaworld_mp_wrapper.py
+++ b/fancy_gym/meta/base_metaworld_mp_wrapper.py
@@ -9,12 +9,9 @@ class BaseMetaworldMPWrapper(RawInterfaceWrapper):
     @property
     def current_pos(self) -> Union[float, int, np.ndarray]:
         r_close = self.env.data.get_joint_qpos("r_close")
-        # TODO check if this is correct
-        # return np.hstack([self.env.data.get_body_xpos('hand').flatten() / self.env.action_scale, r_close])
         return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
 
     @property
     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
-        # TODO check if this is correct
         return np.zeros(4, )
         # raise NotImplementedError("Velocity cannot be retrieved.")
diff --git a/test/test_metaworld_envs.py b/test/test_metaworld_envs.py
index 768958d..ed300f4 100644
--- a/test/test_metaworld_envs.py
+++ b/test/test_metaworld_envs.py
@@ -9,7 +9,6 @@ from test.utils import run_env, run_env_determinism
 METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
                  ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
 METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
-print(METAWORLD_MP_IDS)
 SEED = 1
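For context, a minimal smoke-test sketch of one of the metaworld ProDMP variants configured above is shown below. The env id 'AssemblyProDMP-v2', the seed keyword, and the old-gym 4-tuple step return follow the registration pattern and test utilities in this diff, but they are assumptions for illustration, not part of the change.

    import fancy_gym

    # Sketch only: 'AssemblyProDMP-v2' is assumed from the f'{Name}ProDMP-{version}'
    # registration pattern above; pick an id actually listed in
    # ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].
    env = fancy_gym.make('AssemblyProDMP-v2', seed=1)
    env.reset()
    params = env.action_space.sample()          # one MP parameter vector = one (sub-)trajectory
    obs, reward, done, info = env.step(params)  # BlackBoxWrapper rolls out the generated trajectory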