Merge pull request #64 from ALRhub/Add-ProDMP-envs

Add ProDMP Metaworld envs
ottofabian 2023-03-21 15:29:31 +01:00 committed by GitHub
commit 8948505f06
4 changed files with 23 additions and 25 deletions
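
As context for the registration changes below, one of the new movement-primitive Metaworld environments is typically created through fancy_gym's make and rolled out with one MP action per (sub-)trajectory. This is only a hedged sketch: the env id is drawn from the registry this PR extends, and the old-gym style reset/step signature is assumed from the fancy_gym version of this era.

    import fancy_gym

    # Pick any id registered by this PR; the exact naming scheme is defined
    # by the registration loops in the Metaworld init file below.
    env_id = fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'][0]
    env = fancy_gym.make(env_id, seed=1)

    obs = env.reset()
    done = False
    while not done:
        # One action = one set of ProDMP weights; the black-box wrapper
        # executes the resulting trajectory internally.
        obs, reward, done, info = env.step(env.action_space.sample())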


@@ -62,7 +62,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
         self.traj_gen_action_space = self._get_traj_gen_action_space()
         self.action_space = self._get_action_space()
         self.observation_space = self._get_observation_space()
         # rendering
@@ -95,23 +94,16 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
         self.traj_gen.set_params(clipped_params)
         init_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
-        # TODO we could think about initializing with the previous desired value in order to have a smooth transition
-        # at least from the planning point of view.
         condition_pos = self.condition_pos if self.condition_pos is not None else self.current_pos
         condition_vel = self.condition_vel if self.condition_vel is not None else self.current_vel
         self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
         self.traj_gen.set_duration(duration, self.dt)
-        # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
         position = get_numpy(self.traj_gen.get_traj_pos())
         velocity = get_numpy(self.traj_gen.get_traj_vel())
-        # if self.do_replanning:
-        #     # Remove first part of trajectory as this is already over
-        #     position = position[self.current_traj_steps:]
-        #     velocity = velocity[self.current_traj_steps:]
         return position, velocity

     def _get_traj_gen_action_space(self):
@@ -182,12 +174,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                                              t + 1 + self.current_traj_steps)
                         and self.plan_steps < self.max_planning_times):
-                    self.condition_pos = pos if self.condition_on_desired else None
-                    self.condition_vel = vel if self.condition_on_desired else None
+                    if self.condition_on_desired:
+                        self.condition_pos = pos
+                        self.condition_vel = vel
                     break
-        infos.update({k: v[:t+1] for k, v in infos.items()})
+        infos.update({k: v[:t + 1] for k, v in infos.items()})
         self.current_traj_steps += t + 1
         if self.verbose >= 2:
@@ -210,6 +203,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.current_traj_steps = 0
         self.plan_steps = 0
         self.traj_gen.reset()
-        self.condition_vel = None
         self.condition_pos = None
+        self.condition_vel = None
         return super(BlackBoxWrapper, self).reset()
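
The conditioning logic touched above boils down to the standalone sketch below. It is an illustration of the behaviour, not the wrapper's actual method: traj_gen stands in for the MP_PyTorch trajectory generator, and only the calls visible in the diff (set_params, set_initial_conditions, set_duration, get_traj_pos, get_traj_vel) are assumed.

    import numpy as np

    def get_trajectory_sketch(traj_gen, action, low, high, dt, duration,
                              do_replanning, current_traj_steps,
                              condition_pos, condition_vel,
                              current_pos, current_vel):
        # Clip the raw MP parameters to the trajectory generator's bounds.
        traj_gen.set_params(np.clip(action, low, high))

        # When replanning, the new segment starts at the already elapsed time.
        init_time = np.array(0 if not do_replanning else current_traj_steps * dt)

        # Condition on the stored desired state if it exists,
        # otherwise fall back to the currently measured position/velocity.
        pos = condition_pos if condition_pos is not None else current_pos
        vel = condition_vel if condition_vel is not None else current_vel
        traj_gen.set_initial_conditions(init_time, pos, vel)
        traj_gen.set_duration(duration, dt)

        return traj_gen.get_traj_pos(), traj_gen.get_traj_vel()

After this PR, condition_pos/condition_vel are only populated when condition_on_desired is enabled, so the fallback to the measured state is the default path.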


@@ -13,7 +13,8 @@ DEFAULT_BB_DICT_ProMP = {
     "name": 'EnvName',
     "wrappers": [],
     "trajectory_generator_kwargs": {
-        'trajectory_generator_type': 'promp'
+        'trajectory_generator_type': 'promp',
+        'weights_scale': 10,
     },
     "phase_generator_kwargs": {
         'phase_generator_type': 'linear'
@@ -25,6 +26,9 @@ DEFAULT_BB_DICT_ProMP = {
         'basis_generator_type': 'zero_rbf',
         'num_basis': 5,
         'num_basis_zero_start': 1
+    },
+    'black_box_kwargs': {
+        'condition_on_desired': False,
     }
 }
@@ -32,22 +36,28 @@ DEFAULT_BB_DICT_ProDMP = {
     "name": 'EnvName',
     "wrappers": [],
     "trajectory_generator_kwargs": {
-        'trajectory_generator_type': 'prodmp'
+        'trajectory_generator_type': 'prodmp',
+        'auto_scale_basis': True,
+        'weights_scale': 10,
+        # 'goal_scale': 0.,
+        'disable_goal': True,
     },
     "phase_generator_kwargs": {
-        'phase_generator_type': 'exp'
+        'phase_generator_type': 'exp',
+        # 'alpha_phase' : 3,
     },
     "controller_kwargs": {
         'controller_type': 'metaworld',
     },
     "basis_generator_kwargs": {
         'basis_generator_type': 'prodmp',
-        'num_basis': 5
+        'num_basis': 3,
+        'alpha': 10
     },
-    "black_box_kwargs": {
-        'replanning_schedule': None,
-        'max_planning_times': None,
+    'black_box_kwargs': {
+        'condition_on_desired': False,
     }
 }
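
Each task then gets its own copy of these defaults before registration. The snippet below is a hedged sketch of that pattern: deepcopy and the 'metaworld:<task>' name format follow conventions visible elsewhere in the diff, while the actual registration call is omitted.

    from copy import deepcopy

    kwargs_dict = deepcopy(DEFAULT_BB_DICT_ProDMP)
    kwargs_dict['name'] = f'metaworld:{_task}'               # e.g. 'metaworld:assembly-v2'
    kwargs_dict['wrappers'].append(BaseMetaworldMPWrapper)   # raw-interface wrapper changed in this PR
    # Per-task tweaks would go here, e.g. enabling replanning via 'black_box_kwargs'.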
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2", _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
@@ -152,7 +162,6 @@ for _task in _goal_and_object_change_envs:
     )
     ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
 _goal_and_endeffector_change_envs = ["basketball-v2"]
 for _task in _goal_and_endeffector_change_envs:
     task_id_split = _task.split("-")


@@ -9,12 +9,9 @@ class BaseMetaworldMPWrapper(RawInterfaceWrapper):
     @property
     def current_pos(self) -> Union[float, int, np.ndarray]:
         r_close = self.env.data.get_joint_qpos("r_close")
-        # TODO check if this is correct
-        # return np.hstack([self.env.data.get_body_xpos('hand').flatten() / self.env.action_scale, r_close])
         return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])

     @property
     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
-        # TODO check if this is correct
         return np.zeros(4, )
         # raise NotImplementedError("Velocity cannot be retrieved.")


@@ -9,7 +9,6 @@ from test.utils import run_env, run_env_determinism
 METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
                  ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
 METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
-print(METAWORLD_MP_IDS)
 SEED = 1
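
The collected ids are then typically consumed by parametrized tests. A plausible sketch of that usage follows; the actual test function names and run_env arguments in the test module may differ.

    import pytest

    @pytest.mark.parametrize('env_id', METAWORLD_MP_IDS)
    def test_metaworld_mp_envs(env_id):
        # run_env comes from test.utils (imported in the hunk header above).
        run_env(env_id)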