use desired point as boundary condition
This commit is contained in:
parent
556bfd0b35
commit
a1d96e6016
@ -21,7 +21,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
learn_sub_trajectories: bool = False,
|
learn_sub_trajectories: bool = False,
|
||||||
replanning_schedule: Optional[
|
replanning_schedule: Optional[
|
||||||
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
||||||
reward_aggregation: Callable[[np.ndarray], float] = np.sum
|
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
||||||
|
desired_conditioning: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
gym.Wrapper for leveraging a black box approach with a trajectory generator.
|
gym.Wrapper for leveraging a black box approach with a trajectory generator.
|
||||||
@ -67,6 +68,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.render_kwargs = {}
|
self.render_kwargs = {}
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
|
# condition value
|
||||||
|
self.desired_conditioning = True
|
||||||
|
self.condition_pos = None
|
||||||
|
self.condition_vel = None
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
# return context space if we are
|
# return context space if we are
|
||||||
if self.return_context_observation:
|
if self.return_context_observation:
|
||||||
@ -87,7 +93,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
||||||
# TODO we could think about initializing with the previous desired value in order to have a smooth transition
|
# TODO we could think about initializing with the previous desired value in order to have a smooth transition
|
||||||
# at least from the planning point of view.
|
# at least from the planning point of view.
|
||||||
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
|
# self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
|
||||||
|
if self.current_traj_steps == 0:
|
||||||
|
self.condition_pos = self.current_pos
|
||||||
|
self.condition_vel = self.current_vel
|
||||||
|
self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
|
||||||
self.traj_gen.set_duration(duration, self.dt)
|
self.traj_gen.set_duration(duration, self.dt)
|
||||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||||
position = get_numpy(self.traj_gen.get_traj_pos())
|
position = get_numpy(self.traj_gen.get_traj_pos())
|
||||||
@ -165,14 +175,22 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
|
|
||||||
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
||||||
t + 1 + self.current_traj_steps):
|
t + 1 + self.current_traj_steps):
|
||||||
|
if self.desired_conditioning:
|
||||||
|
self.condition_pos = pos
|
||||||
|
self.condition_vel = vel
|
||||||
|
else:
|
||||||
|
self.condition_pos = self.current_pos
|
||||||
|
self.condition_vel = self.current_vel
|
||||||
break
|
break
|
||||||
|
|
||||||
infos.update({k: v[:t+1] for k, v in infos.items()})
|
infos.update({k: v[:t+1] for k, v in infos.items()})
|
||||||
self.current_traj_steps += t + 1
|
self.current_traj_steps += t + 1
|
||||||
|
|
||||||
if self.verbose >= 2:
|
if self.verbose >= 2:
|
||||||
infos['positions'] = position
|
infos['desired_pos'] = position[:t+1]
|
||||||
infos['velocities'] = velocity
|
infos['desired_vel'] = velocity[:t+1]
|
||||||
|
infos['current_pos'] = self.current_pos
|
||||||
|
infos['current_vel'] = self.current_vel
|
||||||
infos['step_actions'] = actions[:t + 1]
|
infos['step_actions'] = actions[:t + 1]
|
||||||
infos['step_observations'] = observations[:t + 1]
|
infos['step_observations'] = observations[:t + 1]
|
||||||
infos['step_rewards'] = rewards[:t + 1]
|
infos['step_rewards'] = rewards[:t + 1]
|
||||||
|
@ -360,7 +360,7 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
|
|||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
env = BoxPushingTemporalSpatialSparse(frame_skip=10)
|
env = BoxPushingTemporalSpatialSparse(frame_skip=10)
|
||||||
env.reset()
|
env.reset()
|
||||||
for i in range(1):
|
for i in range(10):
|
||||||
env.reset()
|
env.reset()
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
env.render("human")
|
env.render("human")
|
||||||
|
38
fancy_gym/examples/example_replanning_envs.py
Normal file
38
fancy_gym/examples/example_replanning_envs.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import fancy_gym
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
def plot_trajectory(traj, dof=3):
    """Plot one column of a trajectory against its step index.

    Args:
        traj: 2-D array indexed as ``traj[step, dof]``.
        dof: Index of the column to plot. Defaults to 3, the column that
            was previously hard-coded.
    """
    plt.figure()
    # Give the curve a label: plt.legend() without any labelled artist only
    # emits a warning and draws an empty legend.
    plt.plot(traj[:, dof], label=f"dof {dof}")
    plt.legend()
    plt.show()
|
||||||
|
|
||||||
|
def run_replanning_envs(env_name="BoxPushingProDMP-v0", seed=1, iterations=1, render=True):
    """Roll out random actions in a replanning env and plot the desired positions.

    Each call to ``env.step`` executes one planned segment. The segment's
    desired positions are read from ``info['desired_pos']`` and concatenated
    over the episode; the result is handed to ``plot_trajectory`` once the
    episode is done.

    Args:
        env_name: Id passed to ``fancy_gym.make``.
        seed: Seed passed to ``fancy_gym.make``.
        iterations: Number of episodes to roll out.
        render: If true, call ``env.render(mode="human")`` after every step.

    Raises:
        KeyError: If the step info does not contain ``'desired_pos'``.
    """
    env = fancy_gym.make(env_name, seed=seed)
    env.reset()
    for _ in range(iterations):
        done = False
        # Collect segments in a list instead of writing into a fixed
        # (100, 7) buffer with hard-coded 25-step slices: segment length,
        # episode length and number of joints are then taken from the data.
        desired_pos_segments = []
        # Truthiness test rather than `done is False`: the identity check
        # stops the loop after one step for any falsy value that is not the
        # `False` singleton (e.g. a NumPy boolean).
        while not done:
            ac = env.action_space.sample()
            obs, reward, done, info = env.step(ac)
            desired_pos_segments.append(np.asarray(info['desired_pos']))
            if render:
                env.render(mode="human")
        # The loop only exits once the episode is done, so reset
        # unconditionally before the next iteration.
        env.reset()
        plot_trajectory(np.concatenate(desired_pos_segments, axis=0))
    env.close()
    del env
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: a single, non-rendered rollout of the dense
    # box-pushing ProDMP environment.
    run_replanning_envs(
        env_name="BoxPushingDenseProDMP-v0",
        seed=1,
        iterations=1,
        render=False,
    )
|
Loading…
Reference in New Issue
Block a user