temporal updates
This commit is contained in:
parent 2a39a72af0
commit d384e6e764

@@ -81,13 +81,15 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.verbose = verbose
 
         # condition value
-        self.desired_conditioning = True
+        self.desired_conditioning = False
         self.condition_pos = None
         self.condition_vel = None
 
         self.max_planning_times = max_planning_times
         self.plan_counts = 0
 
+        self.tau_first_prediction = None
+
     def observation(self, observation):
         # return context space if we are
         if self.return_context_observation:
@@ -96,7 +98,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         return observation.astype(self.observation_space.dtype)
 
     def get_trajectory(self, action: np.ndarray) -> Tuple:
-        duration = self.duration
+        # duration = self.duration
+        # duration = self.duration - self.current_traj_steps * self.dt
+        duration = 2.
         if self.learn_sub_trajectories:
             duration = None
             # reset  with every new call as we need to set all arguments, such as tau, delay, again.
@@ -117,7 +121,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.condition_pos = torch.as_tensor(self.condition_pos, dtype=torch.float32)
         self.condition_vel = torch.as_tensor(self.condition_vel, dtype=torch.float32)
         self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
-        self.traj_gen.set_duration(duration, self.dt)
+        # self.traj_gen.set_duration(duration, self.dt)
+        self.traj_gen.set_duration(self.tau_first_prediction, self.dt)
         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
         position = get_numpy(self.traj_gen.get_traj_pos())
         velocity = get_numpy(self.traj_gen.get_traj_vel())
@@ -160,6 +165,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def step(self, action: np.ndarray):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
 
+        if self.plan_counts == 0:
+            self.tau_first_prediction = action[0]
+
         ## tricky part, only use weights basis
         # basis_weights = action.reshape(7, -1)
         # goal_weights = np.zeros((7, 1))
@@ -253,5 +261,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
         self.current_traj_steps = 0
         self.plan_counts = 0
+        self.tau_first_prediction = None
         self.traj_gen.reset()
         return super(BlackBoxWrapper, self).reset()

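Note: taken together, the BlackBoxWrapper hunks above thread a new attribute, tau_first_prediction, through the planning loop: reset() clears it, the first plan of an episode in step() captures it from action[0], and get_trajectory() passes it to traj_gen.set_duration() on every subsequent replan, so all sub-trajectories reuse the tau predicted at the first planning step. A minimal sketch of that control flow, with the trajectory generator stubbed out (TrajGenStub and TauFlowSketch are illustrations, not the real mp_pytorch API):

    import numpy as np

    class TrajGenStub:
        """Placeholder trajectory generator; only records set_duration calls."""
        def set_duration(self, duration, dt):
            self.duration, self.dt = duration, dt

    class TauFlowSketch:
        """How tau_first_prediction flows through reset/step/get_trajectory."""
        def __init__(self, dt=0.02):
            self.dt = dt
            self.traj_gen = TrajGenStub()
            self.plan_counts = 0
            self.tau_first_prediction = None

        def reset(self):
            self.plan_counts = 0
            self.tau_first_prediction = None  # cleared per episode, as in the diff

        def step(self, action):
            if self.plan_counts == 0:
                # first plan of the episode: tau is the first MP parameter
                self.tau_first_prediction = action[0]
            self.get_trajectory(action)
            self.plan_counts += 1

        def get_trajectory(self, action):
            # every (re)plan reuses the tau predicted at the first planning step
            self.traj_gen.set_duration(self.tau_first_prediction, self.dt)

    env = TauFlowSketch()
    env.reset()
    env.step(np.array([1.5, 0.1, 0.2]))  # tau_first_prediction becomes 1.5
    env.step(np.array([9.9, 0.3, 0.4]))  # later taus are ignored; duration stays 1.5
    assert env.traj_gen.duration == 1.5
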
@@ -254,7 +254,7 @@ for ctxt_dim in [2, 4]:
     register(
         id='TableTennis{}D-v0'.format(ctxt_dim),
         entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
-        max_episode_steps=350,
+        max_episode_steps=500,
         kwargs={
             "ctxt_dim": ctxt_dim,
             'frame_skip': 4
@@ -523,7 +523,7 @@ for _v in _versions:
     kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
     kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
     kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
-    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2
+    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4
     kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
     register(
         id=_env_id,

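Note: the replanning_schedule set here is a callable (pos, vel, obs, action, t) -> bool that the black-box wrapper consults at each environment step t, while max_planning_times caps how many plans one episode may use; with t % 25 == 0 and the cap raised from 2 to 4, an episode can replan at steps 0, 25, 50 and 75. A small self-contained illustration of that interplay (the loop is illustrative, not the wrapper's actual code):

    # the schedule from the registration above: replan whenever t is a multiple of 25
    replanning_schedule = lambda pos, vel, obs, action, t: t % 25 == 0

    max_planning_times = 4
    plan_counts = 0
    for t in range(100):
        if plan_counts < max_planning_times and replanning_schedule(None, None, None, None, t):
            plan_counts += 1  # fires at t = 0, 25, 50, 75
    print(plan_counts)  # -> 4: the raised cap is now the binding limit
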
@@ -555,6 +555,35 @@ for _v in _versions:
         kwargs=kwargs_dict_tt_promp
     )
     ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+
+for _v in _versions:
+    _name = _v.split("-")
+    _env_id = f'{_name[0]}ProDMP-{_name[1]}'
+    kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
+    kwargs_dict_tt_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
+    kwargs_dict_tt_prodmp['name'] = _v
+    kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
+    kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
+    kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['weights_scale'] = 1.0
+    kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0
+    kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = True
+    kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_tau'] = True
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True
+    kwargs_dict_tt_prodmp['basis_generator_kwargs']['num_basis'] = 2
+    kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25.
+    kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
+    kwargs_dict_tt_prodmp['basis_generator_kwargs']['pre_compute_length_factor'] = 5
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
+    # kwargs_dict_tt_prodmp['black_box_kwargs']['duration'] = 4.
+    kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 4
+    kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0
+    register(
+        id=_env_id,
+        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_tt_prodmp
+    )
+    ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
 #
 # ## Walker2DJump
 # _versions = ['Walker2DJump-v0']

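Note: once this loop runs, the new ids (e.g. TableTennis2DProDMP-v0 and TableTennis4DProDMP-v0, from the TableTennis{ctxt_dim}D-v0 versions above) can be instantiated like any other fancy_gym environment. A hedged usage sketch, assuming the usual fancy_gym.make helper and that, with learn_tau and learn_delay enabled, the first action entries are tau and delay:

    import fancy_gym

    env = fancy_gym.make('TableTennis4DProDMP-v0', seed=10)
    obs = env.reset()
    ac = env.action_space.sample()  # [tau, delay, basis weights, ...] under learn_tau/learn_delay
    obs, reward, done, info = env.step(ac)  # executes one planned (sub-)trajectory
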
@@ -127,8 +127,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
 
     def reset_model(self):
         self._steps = 0
-        self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
-        self._goal_pos = self._generate_goal_pos(random=True)
+        self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
+        self._goal_pos = self._generate_goal_pos(random=False)
         self.data.joint("tar_x").qpos = self._init_ball_state[0]
         self.data.joint("tar_y").qpos = self._init_ball_state[1]
         self.data.joint("tar_z").qpos = self._init_ball_state[2]
@@ -214,6 +214,6 @@ if __name__ == "__main__":
         for _ in range(2000):
             env.render("human")
             obs, reward, done, info = env.step(np.zeros(7))
-            print(reward)
+            # print(reward)
             if done:
                 break

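Note: with random_pos=False and random=False, every reset now starts from the same ball state and goal, which makes runs reproducible while the replanning changes above are being debugged. If both behaviours were meant to coexist, a flag on reset_model would avoid editing the call sites; a toy sketch (random_init, the stub generator, and the values are hypothetical, not the env's API):

    import numpy as np

    rng = np.random.default_rng(0)

    def generate_goal_pos(random: bool) -> np.ndarray:
        """Stub standing in for TableTennisEnv._generate_goal_pos."""
        return rng.uniform(-1.2, -0.2, size=2) if random else np.array([-0.5, -0.5])

    def reset_model(random_init: bool = False) -> np.ndarray:
        # hypothetical flag: True restores the old randomized reset
        return generate_goal_pos(random=random_init)

    print(reset_model())                  # same fixed goal every call
    print(reset_model(random_init=True))  # varies run to run
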
@@ -40,6 +40,8 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
         # Now the action space is not the raw action but the parametrization of the trajectory generator,
         # such as a ProMP
         ac = env.action_space.sample()
+        # ac[0] = 0.6866657733917236
+        # ac[1] = 0.08587364107370377
         # This executes a full trajectory and gives back the context (obs) of the last step in the trajectory, or the
         # full observation space of the last step, if replanning/sub-trajectory learning is used. The 'reward' is equal
         # to the return of a trajectory. Default is the sum over the step-wise rewards.
@@ -50,7 +52,8 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
         if done:
+            # print(reward)
             obs = env.reset()
-            print("steps: {}".format(info["num_steps"][-1]))
             print("=================New Episode======================")
+            # print("steps: {}".format(info["num_steps"][-1]))
 
 
 def example_custom_mp(env_name="Reacher5dProMP-v0", seed=1, iterations=1, render=True):
@@ -158,17 +161,18 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
 
 
 if __name__ == '__main__':
-    render = True
+    render = False
     # DMP
     # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
 
     # ProMP
     # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
     # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
-    example_mp("TableTennis4DProMP-v0", seed=10, iterations=10, render=True)
+    # example_mp("TableTennis4DProMP-v0", seed=10, iterations=10, render=True)
 
     # ProDMP
     # example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
+    example_mp("TableTennis4DProDMP-v0", seed=10, iterations=5000, render=render)
 
     # Altered basis functions
    # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)

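Note: the example's entry point now runs the ProDMP table-tennis task for 5000 iterations with rendering off, and the per-episode output is reduced to the episode marker. Condensed, the loop example_mp executes for this configuration looks roughly like the following (paraphrased from the example, not copied; in the replanning setting each env.step consumes one sub-trajectory, so an episode can span several loop iterations):

    import fancy_gym

    env = fancy_gym.make("TableTennis4DProDMP-v0", seed=10)
    obs = env.reset()
    for _ in range(5000):
        ac = env.action_space.sample()  # MP parameters, not a raw torque command
        obs, reward, done, info = env.step(ac)  # runs one planned sub-trajectory
        if done:
            obs = env.reset()
            print("=================New Episode======================")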