Merge pull request #64 from ALRhub/Add-ProDMP-envs
Add prodmp metaworld envs
This commit is contained in:
commit
8948505f06
@ -62,7 +62,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
|
||||
self.traj_gen_action_space = self._get_traj_gen_action_space()
|
||||
self.action_space = self._get_action_space()
|
||||
|
||||
self.observation_space = self._get_observation_space()
|
||||
|
||||
# rendering
|
||||
@ -95,23 +94,16 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
|
||||
self.traj_gen.set_params(clipped_params)
|
||||
init_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
||||
# TODO we could think about initializing with the previous desired value in order to have a smooth transition
|
||||
# at least from the planning point of view.
|
||||
|
||||
condition_pos = self.condition_pos if self.condition_pos is not None else self.current_pos
|
||||
condition_vel = self.condition_vel if self.condition_vel is not None else self.current_vel
|
||||
|
||||
self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
|
||||
self.traj_gen.set_duration(duration, self.dt)
|
||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||
|
||||
position = get_numpy(self.traj_gen.get_traj_pos())
|
||||
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
||||
|
||||
# if self.do_replanning:
|
||||
# # Remove first part of trajectory as this is already over
|
||||
# position = position[self.current_traj_steps:]
|
||||
# velocity = velocity[self.current_traj_steps:]
|
||||
|
||||
return position, velocity
|
||||
|
||||
def _get_traj_gen_action_space(self):
|
||||
@ -182,8 +174,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
t + 1 + self.current_traj_steps)
|
||||
and self.plan_steps < self.max_planning_times):
|
||||
|
||||
self.condition_pos = pos if self.condition_on_desired else None
|
||||
self.condition_vel = vel if self.condition_on_desired else None
|
||||
if self.condition_on_desired:
|
||||
self.condition_pos = pos
|
||||
self.condition_vel = vel
|
||||
|
||||
break
|
||||
|
||||
@ -210,6 +203,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
self.current_traj_steps = 0
|
||||
self.plan_steps = 0
|
||||
self.traj_gen.reset()
|
||||
self.condition_vel = None
|
||||
self.condition_pos = None
|
||||
self.condition_vel = None
|
||||
return super(BlackBoxWrapper, self).reset()
|
||||
|
@ -13,7 +13,8 @@ DEFAULT_BB_DICT_ProMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'promp'
|
||||
'trajectory_generator_type': 'promp',
|
||||
'weights_scale': 10,
|
||||
},
|
||||
"phase_generator_kwargs": {
|
||||
'phase_generator_type': 'linear'
|
||||
@ -25,6 +26,9 @@ DEFAULT_BB_DICT_ProMP = {
|
||||
'basis_generator_type': 'zero_rbf',
|
||||
'num_basis': 5,
|
||||
'num_basis_zero_start': 1
|
||||
},
|
||||
'black_box_kwargs': {
|
||||
'condition_on_desired': False,
|
||||
}
|
||||
}
|
||||
|
||||
@ -32,22 +36,28 @@ DEFAULT_BB_DICT_ProDMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'prodmp'
|
||||
'trajectory_generator_type': 'prodmp',
|
||||
'auto_scale_basis': True,
|
||||
'weights_scale': 10,
|
||||
# 'goal_scale': 0.,
|
||||
'disable_goal': True,
|
||||
},
|
||||
"phase_generator_kwargs": {
|
||||
'phase_generator_type': 'exp'
|
||||
'phase_generator_type': 'exp',
|
||||
# 'alpha_phase' : 3,
|
||||
},
|
||||
"controller_kwargs": {
|
||||
'controller_type': 'metaworld',
|
||||
},
|
||||
"basis_generator_kwargs": {
|
||||
'basis_generator_type': 'prodmp',
|
||||
'num_basis': 5
|
||||
'num_basis': 3,
|
||||
'alpha': 10
|
||||
},
|
||||
"black_box_kwargs": {
|
||||
'replanning_schedule': None,
|
||||
'max_planning_times': None,
|
||||
'black_box_kwargs': {
|
||||
'condition_on_desired': False,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
||||
@ -152,7 +162,6 @@ for _task in _goal_and_object_change_envs:
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||
|
||||
|
||||
_goal_and_endeffector_change_envs = ["basketball-v2"]
|
||||
for _task in _goal_and_endeffector_change_envs:
|
||||
task_id_split = _task.split("-")
|
||||
|
@ -9,12 +9,9 @@ class BaseMetaworldMPWrapper(RawInterfaceWrapper):
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||
r_close = self.env.data.get_joint_qpos("r_close")
|
||||
# TODO check if this is correct
|
||||
# return np.hstack([self.env.data.get_body_xpos('hand').flatten() / self.env.action_scale, r_close])
|
||||
return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
|
||||
|
||||
@property
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
# TODO check if this is correct
|
||||
return np.zeros(4, )
|
||||
# raise NotImplementedError("Velocity cannot be retrieved.")
|
||||
|
@ -9,7 +9,6 @@ from test.utils import run_env, run_env_determinism
|
||||
METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
|
||||
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
|
||||
METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||
print(METAWORLD_MP_IDS)
|
||||
SEED = 1
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user