diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index fb7328c..9b3e95e 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -50,7 +50,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): self.tracking_controller = tracking_controller # self.time_steps = np.linspace(0, self.duration, self.traj_steps) # self.traj_gen.set_mp_times(self.time_steps) - self.traj_gen.set_duration(self.duration - self.dt, self.dt) + self.traj_gen.set_duration(self.duration, self.dt) # reward computation self.reward_aggregation = reward_aggregation @@ -78,11 +78,13 @@ class BlackBoxWrapper(gym.ObservationWrapper): bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel) # TODO: remove the - self.dt after Bruces fix. - self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration - self.dt, self.dt) + self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, self.dt) # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) trajectory = get_numpy(self.traj_gen.get_traj_pos()) velocity = get_numpy(self.traj_gen.get_traj_vel()) + print(len(trajectory)) + if self.do_replanning: # Remove first part of trajectory as this is already over trajectory = trajectory[self.current_traj_steps:] diff --git a/fancy_gym/examples/examples_movement_primitives.py b/fancy_gym/examples/examples_movement_primitives.py index 22e95ac..3794112 100644 --- a/fancy_gym/examples/examples_movement_primitives.py +++ b/fancy_gym/examples/examples_movement_primitives.py @@ -106,31 +106,31 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): """ - base_env_id = "HoleReacher-v0" + base_env_id = "Reacher5d-v0" # Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper. # You can also add other gym.Wrappers in case they are needed. - wrappers = [fancy_gym.envs.classic_control.hole_reacher.MPWrapper] + wrappers = [fancy_gym.envs.mujoco.reacher.MPWrapper] - # # For a ProMP - # trajectory_generator_kwargs = {'trajectory_generator_type': 'promp', - # 'weight_scale': 2} - # phase_generator_kwargs = {'phase_generator_type': 'linear'} - # controller_kwargs = {'controller_type': 'velocity'} - # basis_generator_kwargs = {'basis_generator_type': 'zero_rbf', - # 'num_basis': 5, - # 'num_basis_zero_start': 1 - # } - - # For a DMP - trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp', - 'weight_scale': 500} - phase_generator_kwargs = {'phase_generator_type': 'exp', - 'alpha_phase': 2.5} + # For a ProMP + trajectory_generator_kwargs = {'trajectory_generator_type': 'promp', + 'weight_scale': 2} + phase_generator_kwargs = {'phase_generator_type': 'linear'} controller_kwargs = {'controller_type': 'velocity'} - basis_generator_kwargs = {'basis_generator_type': 'rbf', - 'num_basis': 5 + basis_generator_kwargs = {'basis_generator_type': 'zero_rbf', + 'num_basis': 5, + 'num_basis_zero_start': 1 } + + # # For a DMP + # trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp', + # 'weight_scale': 500} + # phase_generator_kwargs = {'phase_generator_type': 'exp', + # 'alpha_phase': 2.5} + # controller_kwargs = {'controller_type': 'velocity'} + # basis_generator_kwargs = {'basis_generator_type': 'rbf', + # 'num_basis': 5 + # } env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={}, traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs, phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs, @@ -155,15 +155,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): if __name__ == '__main__': - render = True + render = False # DMP - example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) + # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) # # # ProMP - example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) + # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) # Altered basis functions - obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=5, render=render) + # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) # Custom MP example_fully_custom_mp(seed=10, iterations=1, render=render)