updated for new mp_pytorch
This commit is contained in:
parent
8f07770a2f
commit
5c8ba41e04
@ -100,15 +100,10 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
|
|
||||||
self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
|
self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
|
||||||
self.traj_gen.set_duration(duration, self.dt)
|
self.traj_gen.set_duration(duration, self.dt)
|
||||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
|
||||||
position = get_numpy(self.traj_gen.get_traj_pos())
|
position = get_numpy(self.traj_gen.get_traj_pos())
|
||||||
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
||||||
|
|
||||||
# if self.do_replanning:
|
|
||||||
# # Remove first part of trajectory as this is already over
|
|
||||||
# position = position[self.current_traj_steps:]
|
|
||||||
# velocity = velocity[self.current_traj_steps:]
|
|
||||||
|
|
||||||
return position, velocity
|
return position, velocity
|
||||||
|
|
||||||
def _get_traj_gen_action_space(self):
|
def _get_traj_gen_action_space(self):
|
||||||
@ -179,8 +174,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
t + 1 + self.current_traj_steps)
|
t + 1 + self.current_traj_steps)
|
||||||
and self.plan_steps < self.max_planning_times):
|
and self.plan_steps < self.max_planning_times):
|
||||||
|
|
||||||
self.condition_pos = pos if self.condition_on_desired else None
|
if self.condition_on_desired:
|
||||||
self.condition_vel = vel if self.condition_on_desired else None
|
self.condition_pos = pos
|
||||||
|
self.condition_vel = vel
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -207,6 +203,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.current_traj_steps = 0
|
self.current_traj_steps = 0
|
||||||
self.plan_steps = 0
|
self.plan_steps = 0
|
||||||
self.traj_gen.reset()
|
self.traj_gen.reset()
|
||||||
self.condition_vel = None
|
|
||||||
self.condition_pos = None
|
self.condition_pos = None
|
||||||
|
self.condition_vel = None
|
||||||
return super(BlackBoxWrapper, self).reset()
|
return super(BlackBoxWrapper, self).reset()
|
||||||
|
@ -26,6 +26,9 @@ DEFAULT_BB_DICT_ProMP = {
|
|||||||
'basis_generator_type': 'zero_rbf',
|
'basis_generator_type': 'zero_rbf',
|
||||||
'num_basis': 5,
|
'num_basis': 5,
|
||||||
'num_basis_zero_start': 1
|
'num_basis_zero_start': 1
|
||||||
|
},
|
||||||
|
'black_box_kwargs': {
|
||||||
|
'condition_on_desired': False,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,7 +53,11 @@ DEFAULT_BB_DICT_ProDMP = {
|
|||||||
'basis_generator_type': 'prodmp',
|
'basis_generator_type': 'prodmp',
|
||||||
'num_basis': 3,
|
'num_basis': 3,
|
||||||
'alpha': 10
|
'alpha': 10
|
||||||
|
},
|
||||||
|
'black_box_kwargs': {
|
||||||
|
'condition_on_desired': False,
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
||||||
|
Loading…
Reference in New Issue
Block a user