use desired point as boundary condition

commit a1d96e6016 (parent 556bfd0b35)
@@ -21,7 +21,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                  learn_sub_trajectories: bool = False,
                  replanning_schedule: Optional[
                      Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
-                 reward_aggregation: Callable[[np.ndarray], float] = np.sum
+                 reward_aggregation: Callable[[np.ndarray], float] = np.sum,
+                 desired_conditioning: bool = False
                  ):
         """
         gym.Wrapper for leveraging a black box approach with a trajectory generator.
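The new `desired_conditioning` flag defaults to False, so existing call sites keep their behavior. A minimal sketch of how the extended constructor might be invoked; `env`, `traj_gen`, and `controller` are hypothetical placeholders, and only the keyword names visible in this diff are taken from the source:

    # Hypothetical call site (placeholders: env, traj_gen, controller).
    wrapper = BlackBoxWrapper(
        env,
        traj_gen,
        controller,
        duration=2.0,
        reward_aggregation=np.sum,    # unchanged default
        desired_conditioning=True,    # new: condition replanning on the desired state
    )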
@@ -67,6 +68,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.render_kwargs = {}
         self.verbose = verbose
 
+        # condition value used as boundary condition for the trajectory generator
+        self.desired_conditioning = desired_conditioning
+        self.condition_pos = None
+        self.condition_vel = None
+
     def observation(self, observation):
         # return context space if we are
         if self.return_context_observation:
@@ -87,7 +93,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
         # TODO we could think about initializing with the previous desired value in order to have a smooth transition
         # at least from the planning point of view.
-        self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
+        # self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
+        if self.current_traj_steps == 0:
+            self.condition_pos = self.current_pos
+            self.condition_vel = self.current_vel
+        self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
         self.traj_gen.set_duration(duration, self.dt)
         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
         position = get_numpy(self.traj_gen.get_traj_pos())
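In effect, the first segment of an episode is still conditioned on the measured state, while later segments reuse whatever was stored in condition_pos/condition_vel at the last replanning point. A sketch of that selection logic in isolation, assuming only the set_boundary_conditions(time, pos, vel) interface shown above:

    # Sketch: choose the state the next trajectory segment is conditioned on.
    def _boundary_state(self):
        if self.current_traj_steps == 0:
            # First segment: no previous plan exists, use the measured state.
            self.condition_pos = self.current_pos
            self.condition_vel = self.current_vel
        # Otherwise keep the values stored when the last segment ended
        # (desired or measured, depending on `desired_conditioning`).
        return self.condition_pos, self.condition_vel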
@@ -165,14 +175,22 @@ class BlackBoxWrapper(gym.ObservationWrapper):
 
             if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
                                                 t + 1 + self.current_traj_steps):
+                if self.desired_conditioning:
+                    self.condition_pos = pos
+                    self.condition_vel = vel
+                else:
+                    self.condition_pos = self.current_pos
+                    self.condition_vel = self.current_vel
                 break
 
         infos.update({k: v[:t+1] for k, v in infos.items()})
         self.current_traj_steps += t + 1
 
         if self.verbose >= 2:
-            infos['positions'] = position
-            infos['velocities'] = velocity
+            infos['desired_pos'] = position[:t+1]
+            infos['desired_vel'] = velocity[:t+1]
+            infos['current_pos'] = self.current_pos
+            infos['current_vel'] = self.current_vel
             infos['step_actions'] = actions[:t + 1]
             infos['step_observations'] = observations[:t + 1]
             infos['step_rewards'] = rewards[:t + 1]
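With verbose >= 2, the planned segment and the measured end state are now exposed through the step info, which is what the new example script below relies on. A minimal consumer sketch; only the dict keys come from this diff, everything else is hypothetical:

    obs, reward, done, info = env.step(env.action_space.sample())
    planned_pos = info['desired_pos']    # planned joint positions of the executed segment
    planned_vel = info['desired_vel']    # planned joint velocities
    measured_pos = info['current_pos']   # measured state after the segment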
@@ -360,7 +360,7 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
 if __name__=="__main__":
     env = BoxPushingTemporalSpatialSparse(frame_skip=10)
     env.reset()
-    for i in range(1):
+    for i in range(10):
         env.reset()
         for _ in range(100):
             env.render("human")
fancy_gym/examples/example_replanning_envs.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+import fancy_gym
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def plot_trajectory(traj):
+    plt.figure()
+    plt.plot(traj[:, 3], label='dim 3')
+    plt.legend()
+    plt.show()
+
+
+def run_replanning_envs(env_name="BoxPushingProDMP-v0", seed=1, iterations=1, render=True):
+    env = fancy_gym.make(env_name, seed=seed)
+    env.reset()
+    for i in range(iterations):
+        done = False
+        desired_pos_traj = np.zeros((100, 7))
+        desired_vel_traj = np.zeros((100, 7))
+        real_pos_traj = np.zeros((100, 7))
+        real_vel_traj = np.zeros((100, 7))
+        t = 0
+        while not done:
+            ac = env.action_space.sample()
+            obs, reward, done, info = env.step(ac)
+            desired_pos_traj[t: t + 25, :] = info['desired_pos']
+            desired_vel_traj[t: t + 25, :] = info['desired_vel']
+            # real_pos_traj.append(info['current_pos'])
+            # real_vel_traj.append(info['current_vel'])
+            t += 25
+            if render:
+                env.render(mode="human")
+            if done:
+                env.reset()
+        plot_trajectory(desired_pos_traj)
+    env.close()
+    del env
+
+
+if __name__ == "__main__":
+    run_replanning_envs(env_name="BoxPushingDenseProDMP-v0", seed=1, iterations=1, render=False)
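Note that the script hardcodes a 25-step replanning interval and a 100-step horizon for the BoxPushing task. A variant that avoids both assumptions, collecting whatever segment lengths the wrapper reports (hypothetical, relying only on the info keys above):

    # Accumulate planned segments without assuming their length or count.
    segments = []
    done = False
    env.reset()
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
        segments.append(np.atleast_2d(info['desired_pos']))
    desired_pos_traj = np.concatenate(segments, axis=0)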