Updated for the new set_duration function.

This commit is contained in:
Fabian 2022-09-20 11:17:20 +02:00
parent 623adaacd8
commit 1ca14f1c93
2 changed files with 27 additions and 25 deletions

View File

@ -50,7 +50,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.tracking_controller = tracking_controller self.tracking_controller = tracking_controller
# self.time_steps = np.linspace(0, self.duration, self.traj_steps) # self.time_steps = np.linspace(0, self.duration, self.traj_steps)
# self.traj_gen.set_mp_times(self.time_steps) # self.traj_gen.set_mp_times(self.time_steps)
self.traj_gen.set_duration(self.duration - self.dt, self.dt) self.traj_gen.set_duration(self.duration, self.dt)
# reward computation # reward computation
self.reward_aggregation = reward_aggregation self.reward_aggregation = reward_aggregation
@ -78,11 +78,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel) self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
# TODO: remove the - self.dt after Bruce's fix. # TODO: remove the - self.dt after Bruce's fix.
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration - self.dt, self.dt) self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, self.dt)
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
trajectory = get_numpy(self.traj_gen.get_traj_pos()) trajectory = get_numpy(self.traj_gen.get_traj_pos())
velocity = get_numpy(self.traj_gen.get_traj_vel()) velocity = get_numpy(self.traj_gen.get_traj_vel())
print(len(trajectory))
if self.do_replanning: if self.do_replanning:
# Remove first part of trajectory as this is already over # Remove first part of trajectory as this is already over
trajectory = trajectory[self.current_traj_steps:] trajectory = trajectory[self.current_traj_steps:]

View File

@ -106,31 +106,31 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
""" """
base_env_id = "HoleReacher-v0" base_env_id = "Reacher5d-v0"
# Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper. # Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper.
# You can also add other gym.Wrappers in case they are needed. # You can also add other gym.Wrappers in case they are needed.
wrappers = [fancy_gym.envs.classic_control.hole_reacher.MPWrapper] wrappers = [fancy_gym.envs.mujoco.reacher.MPWrapper]
# # For a ProMP # For a ProMP
# trajectory_generator_kwargs = {'trajectory_generator_type': 'promp', trajectory_generator_kwargs = {'trajectory_generator_type': 'promp',
# 'weight_scale': 2} 'weight_scale': 2}
# phase_generator_kwargs = {'phase_generator_type': 'linear'} phase_generator_kwargs = {'phase_generator_type': 'linear'}
# controller_kwargs = {'controller_type': 'velocity'}
# basis_generator_kwargs = {'basis_generator_type': 'zero_rbf',
# 'num_basis': 5,
# 'num_basis_zero_start': 1
# }
# For a DMP
trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp',
'weight_scale': 500}
phase_generator_kwargs = {'phase_generator_type': 'exp',
'alpha_phase': 2.5}
controller_kwargs = {'controller_type': 'velocity'} controller_kwargs = {'controller_type': 'velocity'}
basis_generator_kwargs = {'basis_generator_type': 'rbf', basis_generator_kwargs = {'basis_generator_type': 'zero_rbf',
'num_basis': 5 'num_basis': 5,
'num_basis_zero_start': 1
} }
# # For a DMP
# trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp',
# 'weight_scale': 500}
# phase_generator_kwargs = {'phase_generator_type': 'exp',
# 'alpha_phase': 2.5}
# controller_kwargs = {'controller_type': 'velocity'}
# basis_generator_kwargs = {'basis_generator_type': 'rbf',
# 'num_basis': 5
# }
env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={}, env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={},
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs, traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs, phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
@ -155,15 +155,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__': if __name__ == '__main__':
render = True render = False
# DMP # DMP
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
# #
# # ProMP # # ProMP
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
# Altered basis functions # Altered basis functions
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=5, render=render) # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
# Custom MP # Custom MP
example_fully_custom_mp(seed=10, iterations=1, render=render) example_fully_custom_mp(seed=10, iterations=1, render=render)