updated for new set_duration function.
This commit is contained in:
parent
623adaacd8
commit
1ca14f1c93
@ -50,7 +50,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
self.tracking_controller = tracking_controller
|
||||
# self.time_steps = np.linspace(0, self.duration, self.traj_steps)
|
||||
# self.traj_gen.set_mp_times(self.time_steps)
|
||||
self.traj_gen.set_duration(self.duration - self.dt, self.dt)
|
||||
self.traj_gen.set_duration(self.duration, self.dt)
|
||||
|
||||
# reward computation
|
||||
self.reward_aggregation = reward_aggregation
|
||||
@ -78,11 +78,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
||||
self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
|
||||
# TODO: remove the - self.dt after Bruces fix.
|
||||
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration - self.dt, self.dt)
|
||||
self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, self.dt)
|
||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||
trajectory = get_numpy(self.traj_gen.get_traj_pos())
|
||||
velocity = get_numpy(self.traj_gen.get_traj_vel())
|
||||
|
||||
print(len(trajectory))
|
||||
|
||||
if self.do_replanning:
|
||||
# Remove first part of trajectory as this is already over
|
||||
trajectory = trajectory[self.current_traj_steps:]
|
||||
|
@ -106,31 +106,31 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
|
||||
"""
|
||||
|
||||
base_env_id = "HoleReacher-v0"
|
||||
base_env_id = "Reacher5d-v0"
|
||||
|
||||
# Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper.
|
||||
# You can also add other gym.Wrappers in case they are needed.
|
||||
wrappers = [fancy_gym.envs.classic_control.hole_reacher.MPWrapper]
|
||||
wrappers = [fancy_gym.envs.mujoco.reacher.MPWrapper]
|
||||
|
||||
# # For a ProMP
|
||||
# trajectory_generator_kwargs = {'trajectory_generator_type': 'promp',
|
||||
# 'weight_scale': 2}
|
||||
# phase_generator_kwargs = {'phase_generator_type': 'linear'}
|
||||
# controller_kwargs = {'controller_type': 'velocity'}
|
||||
# basis_generator_kwargs = {'basis_generator_type': 'zero_rbf',
|
||||
# 'num_basis': 5,
|
||||
# 'num_basis_zero_start': 1
|
||||
# }
|
||||
|
||||
# For a DMP
|
||||
trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp',
|
||||
'weight_scale': 500}
|
||||
phase_generator_kwargs = {'phase_generator_type': 'exp',
|
||||
'alpha_phase': 2.5}
|
||||
# For a ProMP
|
||||
trajectory_generator_kwargs = {'trajectory_generator_type': 'promp',
|
||||
'weight_scale': 2}
|
||||
phase_generator_kwargs = {'phase_generator_type': 'linear'}
|
||||
controller_kwargs = {'controller_type': 'velocity'}
|
||||
basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
||||
'num_basis': 5
|
||||
basis_generator_kwargs = {'basis_generator_type': 'zero_rbf',
|
||||
'num_basis': 5,
|
||||
'num_basis_zero_start': 1
|
||||
}
|
||||
|
||||
# # For a DMP
|
||||
# trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp',
|
||||
# 'weight_scale': 500}
|
||||
# phase_generator_kwargs = {'phase_generator_type': 'exp',
|
||||
# 'alpha_phase': 2.5}
|
||||
# controller_kwargs = {'controller_type': 'velocity'}
|
||||
# basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
||||
# 'num_basis': 5
|
||||
# }
|
||||
env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={},
|
||||
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
||||
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
||||
@ -155,15 +155,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = True
|
||||
render = False
|
||||
# DMP
|
||||
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||
#
|
||||
# # ProMP
|
||||
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||
|
||||
# Altered basis functions
|
||||
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=5, render=render)
|
||||
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||
|
||||
# Custom MP
|
||||
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||
|
Loading…
Reference in New Issue
Block a user