Merge pull request #64 from ALRhub/Add-ProDMP-envs
Add prodmp metaworld envs
This commit is contained in:
commit
8948505f06
@ -62,7 +62,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
|
self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
|
||||||
self.traj_gen_action_space = self._get_traj_gen_action_space()
|
self.traj_gen_action_space = self._get_traj_gen_action_space()
|
||||||
self.action_space = self._get_action_space()
|
self.action_space = self._get_action_space()
|
||||||
|
|
||||||
self.observation_space = self._get_observation_space()
|
self.observation_space = self._get_observation_space()
|
||||||
|
|
||||||
# rendering
|
# rendering
|
||||||
@ -95,23 +94,16 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
|
clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
|
||||||
self.traj_gen.set_params(clipped_params)
|
self.traj_gen.set_params(clipped_params)
|
||||||
init_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
init_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
||||||
# TODO we could think about initializing with the previous desired value in order to have a smooth transition
|
|
||||||
# at least from the planning point of view.
|
|
||||||
|
|
||||||
condition_pos = self.condition_pos if self.condition_pos is not None else self.current_pos
|
condition_pos = self.condition_pos if self.condition_pos is not None else self.current_pos
|
||||||
condition_vel = self.condition_vel if self.condition_vel is not None else self.current_vel
|
condition_vel = self.condition_vel if self.condition_vel is not None else self.current_vel
|
||||||
|
|
||||||
self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
|
self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
|
||||||
self.traj_gen.set_duration(duration, self.dt)
|
self.traj_gen.set_duration(duration, self.dt)
|
||||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
|
||||||
position = get_numpy(self.traj_gen.get_traj_pos())
|
position = get_numpy(self.traj_gen.get_traj_pos())
|
||||||
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
||||||
|
|
||||||
# if self.do_replanning:
|
|
||||||
# # Remove first part of trajectory as this is already over
|
|
||||||
# position = position[self.current_traj_steps:]
|
|
||||||
# velocity = velocity[self.current_traj_steps:]
|
|
||||||
|
|
||||||
return position, velocity
|
return position, velocity
|
||||||
|
|
||||||
def _get_traj_gen_action_space(self):
|
def _get_traj_gen_action_space(self):
|
||||||
@ -182,12 +174,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
t + 1 + self.current_traj_steps)
|
t + 1 + self.current_traj_steps)
|
||||||
and self.plan_steps < self.max_planning_times):
|
and self.plan_steps < self.max_planning_times):
|
||||||
|
|
||||||
self.condition_pos = pos if self.condition_on_desired else None
|
if self.condition_on_desired:
|
||||||
self.condition_vel = vel if self.condition_on_desired else None
|
self.condition_pos = pos
|
||||||
|
self.condition_vel = vel
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
infos.update({k: v[:t+1] for k, v in infos.items()})
|
infos.update({k: v[:t + 1] for k, v in infos.items()})
|
||||||
self.current_traj_steps += t + 1
|
self.current_traj_steps += t + 1
|
||||||
|
|
||||||
if self.verbose >= 2:
|
if self.verbose >= 2:
|
||||||
@ -210,6 +203,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.current_traj_steps = 0
|
self.current_traj_steps = 0
|
||||||
self.plan_steps = 0
|
self.plan_steps = 0
|
||||||
self.traj_gen.reset()
|
self.traj_gen.reset()
|
||||||
self.condition_vel = None
|
|
||||||
self.condition_pos = None
|
self.condition_pos = None
|
||||||
|
self.condition_vel = None
|
||||||
return super(BlackBoxWrapper, self).reset()
|
return super(BlackBoxWrapper, self).reset()
|
||||||
|
@ -13,7 +13,8 @@ DEFAULT_BB_DICT_ProMP = {
|
|||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
"wrappers": [],
|
"wrappers": [],
|
||||||
"trajectory_generator_kwargs": {
|
"trajectory_generator_kwargs": {
|
||||||
'trajectory_generator_type': 'promp'
|
'trajectory_generator_type': 'promp',
|
||||||
|
'weights_scale': 10,
|
||||||
},
|
},
|
||||||
"phase_generator_kwargs": {
|
"phase_generator_kwargs": {
|
||||||
'phase_generator_type': 'linear'
|
'phase_generator_type': 'linear'
|
||||||
@ -25,6 +26,9 @@ DEFAULT_BB_DICT_ProMP = {
|
|||||||
'basis_generator_type': 'zero_rbf',
|
'basis_generator_type': 'zero_rbf',
|
||||||
'num_basis': 5,
|
'num_basis': 5,
|
||||||
'num_basis_zero_start': 1
|
'num_basis_zero_start': 1
|
||||||
|
},
|
||||||
|
'black_box_kwargs': {
|
||||||
|
'condition_on_desired': False,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,22 +36,28 @@ DEFAULT_BB_DICT_ProDMP = {
|
|||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
"wrappers": [],
|
"wrappers": [],
|
||||||
"trajectory_generator_kwargs": {
|
"trajectory_generator_kwargs": {
|
||||||
'trajectory_generator_type': 'prodmp'
|
'trajectory_generator_type': 'prodmp',
|
||||||
|
'auto_scale_basis': True,
|
||||||
|
'weights_scale': 10,
|
||||||
|
# 'goal_scale': 0.,
|
||||||
|
'disable_goal': True,
|
||||||
},
|
},
|
||||||
"phase_generator_kwargs": {
|
"phase_generator_kwargs": {
|
||||||
'phase_generator_type': 'exp'
|
'phase_generator_type': 'exp',
|
||||||
|
# 'alpha_phase' : 3,
|
||||||
},
|
},
|
||||||
"controller_kwargs": {
|
"controller_kwargs": {
|
||||||
'controller_type': 'metaworld',
|
'controller_type': 'metaworld',
|
||||||
},
|
},
|
||||||
"basis_generator_kwargs": {
|
"basis_generator_kwargs": {
|
||||||
'basis_generator_type': 'prodmp',
|
'basis_generator_type': 'prodmp',
|
||||||
'num_basis': 5
|
'num_basis': 3,
|
||||||
|
'alpha': 10
|
||||||
},
|
},
|
||||||
"black_box_kwargs": {
|
'black_box_kwargs': {
|
||||||
'replanning_schedule': None,
|
'condition_on_desired': False,
|
||||||
'max_planning_times': None,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
||||||
@ -152,7 +162,6 @@ for _task in _goal_and_object_change_envs:
|
|||||||
)
|
)
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
|
||||||
|
|
||||||
_goal_and_endeffector_change_envs = ["basketball-v2"]
|
_goal_and_endeffector_change_envs = ["basketball-v2"]
|
||||||
for _task in _goal_and_endeffector_change_envs:
|
for _task in _goal_and_endeffector_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
|
@ -9,12 +9,9 @@ class BaseMetaworldMPWrapper(RawInterfaceWrapper):
|
|||||||
@property
|
@property
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
r_close = self.env.data.get_joint_qpos("r_close")
|
r_close = self.env.data.get_joint_qpos("r_close")
|
||||||
# TODO check if this is correct
|
|
||||||
# return np.hstack([self.env.data.get_body_xpos('hand').flatten() / self.env.action_scale, r_close])
|
|
||||||
return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
|
return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
# TODO check if this is correct
|
|
||||||
return np.zeros(4, )
|
return np.zeros(4, )
|
||||||
# raise NotImplementedError("Velocity cannot be retrieved.")
|
# raise NotImplementedError("Velocity cannot be retrieved.")
|
||||||
|
@ -9,7 +9,6 @@ from test.utils import run_env, run_env_determinism
|
|||||||
METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
|
METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
|
||||||
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
|
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
|
||||||
METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||||
print(METAWORLD_MP_IDS)
|
|
||||||
SEED = 1
|
SEED = 1
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user