updated for new mp_pytorch
This commit is contained in:
parent
8f07770a2f
commit
5c8ba41e04
@ -100,15 +100,10 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
|
||||
self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
|
||||
self.traj_gen.set_duration(duration, self.dt)
|
||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||
|
||||
position = get_numpy(self.traj_gen.get_traj_pos())
|
||||
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
||||
|
||||
# if self.do_replanning:
|
||||
# # Remove first part of trajectory as this is already over
|
||||
# position = position[self.current_traj_steps:]
|
||||
# velocity = velocity[self.current_traj_steps:]
|
||||
|
||||
return position, velocity
|
||||
|
||||
def _get_traj_gen_action_space(self):
|
||||
@ -179,8 +174,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
t + 1 + self.current_traj_steps)
|
||||
and self.plan_steps < self.max_planning_times):
|
||||
|
||||
self.condition_pos = pos if self.condition_on_desired else None
|
||||
self.condition_vel = vel if self.condition_on_desired else None
|
||||
if self.condition_on_desired:
|
||||
self.condition_pos = pos
|
||||
self.condition_vel = vel
|
||||
|
||||
break
|
||||
|
||||
@ -207,6 +203,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
self.current_traj_steps = 0
|
||||
self.plan_steps = 0
|
||||
self.traj_gen.reset()
|
||||
self.condition_vel = None
|
||||
self.condition_pos = None
|
||||
self.condition_vel = None
|
||||
return super(BlackBoxWrapper, self).reset()
|
||||
|
@ -26,6 +26,9 @@ DEFAULT_BB_DICT_ProMP = {
|
||||
'basis_generator_type': 'zero_rbf',
|
||||
'num_basis': 5,
|
||||
'num_basis_zero_start': 1
|
||||
},
|
||||
'black_box_kwargs': {
|
||||
'condition_on_desired': False,
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,7 +53,11 @@ DEFAULT_BB_DICT_ProDMP = {
|
||||
'basis_generator_type': 'prodmp',
|
||||
'num_basis': 3,
|
||||
'alpha': 10
|
||||
},
|
||||
'black_box_kwargs': {
|
||||
'condition_on_desired': False,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
||||
|
Loading…
Reference in New Issue
Block a user