Updated for the new set_duration function.
This commit is contained in:
parent
623adaacd8
commit
1ca14f1c93
@ -50,7 +50,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.tracking_controller = tracking_controller
|
self.tracking_controller = tracking_controller
|
||||||
# self.time_steps = np.linspace(0, self.duration, self.traj_steps)
|
# self.time_steps = np.linspace(0, self.duration, self.traj_steps)
|
||||||
# self.traj_gen.set_mp_times(self.time_steps)
|
# self.traj_gen.set_mp_times(self.time_steps)
|
||||||
self.traj_gen.set_duration(self.duration - self.dt, self.dt)
|
self.traj_gen.set_duration(self.duration, self.dt)
|
||||||
|
|
||||||
# reward computation
|
# reward computation
|
||||||
self.reward_aggregation = reward_aggregation
|
self.reward_aggregation = reward_aggregation
|
||||||
@ -78,11 +78,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
||||||
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
|
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
|
||||||
# TODO: remove the - self.dt after Bruce's fix.
|
# TODO: remove the - self.dt after Bruce's fix.
|
||||||
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration - self.dt, self.dt)
|
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, self.dt)
|
||||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||||
trajectory = get_numpy(self.traj_gen.get_traj_pos())
|
trajectory = get_numpy(self.traj_gen.get_traj_pos())
|
||||||
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
||||||
|
|
||||||
|
print(len(trajectory))
|
||||||
|
|
||||||
if self.do_replanning:
|
if self.do_replanning:
|
||||||
# Remove first part of trajectory as this is already over
|
# Remove first part of trajectory as this is already over
|
||||||
trajectory = trajectory[self.current_traj_steps:]
|
trajectory = trajectory[self.current_traj_steps:]
|
||||||
|
@ -106,31 +106,31 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
base_env_id = "HoleReacher-v0"
|
base_env_id = "Reacher5d-v0"
|
||||||
|
|
||||||
# Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper.
|
# Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper.
|
||||||
# You can also add other gym.Wrappers in case they are needed.
|
# You can also add other gym.Wrappers in case they are needed.
|
||||||
wrappers = [fancy_gym.envs.classic_control.hole_reacher.MPWrapper]
|
wrappers = [fancy_gym.envs.mujoco.reacher.MPWrapper]
|
||||||
|
|
||||||
# # For a ProMP
|
# For a ProMP
|
||||||
# trajectory_generator_kwargs = {'trajectory_generator_type': 'promp',
|
trajectory_generator_kwargs = {'trajectory_generator_type': 'promp',
|
||||||
# 'weight_scale': 2}
|
'weight_scale': 2}
|
||||||
# phase_generator_kwargs = {'phase_generator_type': 'linear'}
|
phase_generator_kwargs = {'phase_generator_type': 'linear'}
|
||||||
# controller_kwargs = {'controller_type': 'velocity'}
|
|
||||||
# basis_generator_kwargs = {'basis_generator_type': 'zero_rbf',
|
|
||||||
# 'num_basis': 5,
|
|
||||||
# 'num_basis_zero_start': 1
|
|
||||||
# }
|
|
||||||
|
|
||||||
# For a DMP
|
|
||||||
trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp',
|
|
||||||
'weight_scale': 500}
|
|
||||||
phase_generator_kwargs = {'phase_generator_type': 'exp',
|
|
||||||
'alpha_phase': 2.5}
|
|
||||||
controller_kwargs = {'controller_type': 'velocity'}
|
controller_kwargs = {'controller_type': 'velocity'}
|
||||||
basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
basis_generator_kwargs = {'basis_generator_type': 'zero_rbf',
|
||||||
'num_basis': 5
|
'num_basis': 5,
|
||||||
|
'num_basis_zero_start': 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# # For a DMP
|
||||||
|
# trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp',
|
||||||
|
# 'weight_scale': 500}
|
||||||
|
# phase_generator_kwargs = {'phase_generator_type': 'exp',
|
||||||
|
# 'alpha_phase': 2.5}
|
||||||
|
# controller_kwargs = {'controller_type': 'velocity'}
|
||||||
|
# basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
||||||
|
# 'num_basis': 5
|
||||||
|
# }
|
||||||
env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={},
|
env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={},
|
||||||
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
||||||
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
||||||
@ -155,15 +155,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = False
|
||||||
# DMP
|
# DMP
|
||||||
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||||
#
|
#
|
||||||
# # ProMP
|
# # ProMP
|
||||||
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||||
|
|
||||||
# Altered basis functions
|
# Altered basis functions
|
||||||
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=5, render=render)
|
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# Custom MP
|
# Custom MP
|
||||||
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||||
|
Loading…
Reference in New Issue
Block a user