use desired point as boundary condition

This commit is contained in:
Hongyi Zhou 2022-10-25 22:15:30 +02:00
parent 556bfd0b35
commit a1d96e6016
3 changed files with 61 additions and 5 deletions

View File

@ -21,7 +21,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
learn_sub_trajectories: bool = False, learn_sub_trajectories: bool = False,
replanning_schedule: Optional[ replanning_schedule: Optional[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None, Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
reward_aggregation: Callable[[np.ndarray], float] = np.sum reward_aggregation: Callable[[np.ndarray], float] = np.sum,
desired_conditioning: bool = False
): ):
""" """
gym.Wrapper for leveraging a black box approach with a trajectory generator. gym.Wrapper for leveraging a black box approach with a trajectory generator.
@ -67,6 +68,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.render_kwargs = {} self.render_kwargs = {}
self.verbose = verbose self.verbose = verbose
# condition value
self.desired_conditioning = True
self.condition_pos = None
self.condition_vel = None
def observation(self, observation): def observation(self, observation):
# return context space if we are # return context space if we are
if self.return_context_observation: if self.return_context_observation:
@ -87,7 +93,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
# TODO we could think about initializing with the previous desired value in order to have a smooth transition # TODO we could think about initializing with the previous desired value in order to have a smooth transition
# at least from the planning point of view. # at least from the planning point of view.
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel) # self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
if self.current_traj_steps == 0:
self.condition_pos = self.current_pos
self.condition_vel = self.current_vel
self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
self.traj_gen.set_duration(duration, self.dt) self.traj_gen.set_duration(duration, self.dt)
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
position = get_numpy(self.traj_gen.get_traj_pos()) position = get_numpy(self.traj_gen.get_traj_pos())
@ -165,14 +175,22 @@ class BlackBoxWrapper(gym.ObservationWrapper):
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
t + 1 + self.current_traj_steps): t + 1 + self.current_traj_steps):
if self.desired_conditioning:
self.condition_pos = pos
self.condition_vel = vel
else:
self.condition_pos = self.current_pos
self.condition_vel = self.current_vel
break break
infos.update({k: v[:t+1] for k, v in infos.items()}) infos.update({k: v[:t+1] for k, v in infos.items()})
self.current_traj_steps += t + 1 self.current_traj_steps += t + 1
if self.verbose >= 2: if self.verbose >= 2:
infos['positions'] = position infos['desired_pos'] = position[:t+1]
infos['velocities'] = velocity infos['desired_vel'] = velocity[:t+1]
infos['current_pos'] = self.current_pos
infos['current_vel'] = self.current_vel
infos['step_actions'] = actions[:t + 1] infos['step_actions'] = actions[:t + 1]
infos['step_observations'] = observations[:t + 1] infos['step_observations'] = observations[:t + 1]
infos['step_rewards'] = rewards[:t + 1] infos['step_rewards'] = rewards[:t + 1]

View File

@ -360,7 +360,7 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
if __name__=="__main__": if __name__=="__main__":
env = BoxPushingTemporalSpatialSparse(frame_skip=10) env = BoxPushingTemporalSpatialSparse(frame_skip=10)
env.reset() env.reset()
for i in range(1): for i in range(10):
env.reset() env.reset()
for _ in range(100): for _ in range(100):
env.render("human") env.render("human")

View File

@ -0,0 +1,38 @@
import fancy_gym
import numpy as np
import matplotlib.pyplot as plt
def plot_trajectory(traj):
plt.figure()
plt.plot(traj[:, 3])
plt.legend()
plt.show()
def run_replanning_envs(env_name="BoxPushingProDMP-v0", seed=1, iterations=1, render=True):
env = fancy_gym.make(env_name, seed=seed)
env.reset()
for i in range(iterations):
done = False
desired_pos_traj = np.zeros((100, 7))
desired_vel_traj = np.zeros((100, 7))
real_pos_traj = np.zeros((100, 7))
real_vel_traj = np.zeros((100, 7))
t = 0
while done is False:
ac = env.action_space.sample()
obs, reward, done, info = env.step(ac)
desired_pos_traj[t: t + 25, :] = info['desired_pos']
desired_vel_traj[t: t + 25, :] = info['desired_vel']
# real_pos_traj.append(info['current_pos'])
# real_vel_traj.append(info['current_vel'])
t += 25
if render:
env.render(mode="human")
if done:
env.reset()
plot_trajectory(desired_pos_traj)
env.close()
del env
if __name__ == "__main__":
run_replanning_envs(env_name="BoxPushingDenseProDMP-v0", seed=1, iterations=1, render=False)