updated for new mp_pytorch

This commit is contained in:
Fabian 2023-03-21 15:27:11 +01:00
parent 8f07770a2f
commit 5c8ba41e04
2 changed files with 12 additions and 9 deletions

View File

@ -100,15 +100,10 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel) self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
self.traj_gen.set_duration(duration, self.dt) self.traj_gen.set_duration(duration, self.dt)
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
position = get_numpy(self.traj_gen.get_traj_pos()) position = get_numpy(self.traj_gen.get_traj_pos())
velocity = get_numpy(self.traj_gen.get_traj_vel()) velocity = get_numpy(self.traj_gen.get_traj_vel())
# if self.do_replanning:
# # Remove first part of trajectory as this is already over
# position = position[self.current_traj_steps:]
# velocity = velocity[self.current_traj_steps:]
return position, velocity return position, velocity
def _get_traj_gen_action_space(self): def _get_traj_gen_action_space(self):
@ -179,8 +174,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
t + 1 + self.current_traj_steps) t + 1 + self.current_traj_steps)
and self.plan_steps < self.max_planning_times): and self.plan_steps < self.max_planning_times):
self.condition_pos = pos if self.condition_on_desired else None if self.condition_on_desired:
self.condition_vel = vel if self.condition_on_desired else None self.condition_pos = pos
self.condition_vel = vel
break break
@ -207,6 +203,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.current_traj_steps = 0 self.current_traj_steps = 0
self.plan_steps = 0 self.plan_steps = 0
self.traj_gen.reset() self.traj_gen.reset()
self.condition_vel = None
self.condition_pos = None self.condition_pos = None
self.condition_vel = None
return super(BlackBoxWrapper, self).reset() return super(BlackBoxWrapper, self).reset()

View File

@ -26,6 +26,9 @@ DEFAULT_BB_DICT_ProMP = {
'basis_generator_type': 'zero_rbf', 'basis_generator_type': 'zero_rbf',
'num_basis': 5, 'num_basis': 5,
'num_basis_zero_start': 1 'num_basis_zero_start': 1
},
'black_box_kwargs': {
'condition_on_desired': False,
} }
} }
@ -50,7 +53,11 @@ DEFAULT_BB_DICT_ProDMP = {
'basis_generator_type': 'prodmp', 'basis_generator_type': 'prodmp',
'num_basis': 3, 'num_basis': 3,
'alpha': 10 'alpha': 10
},
'black_box_kwargs': {
'condition_on_desired': False,
} }
} }
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2", _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",