temporal updates

parent 2a39a72af0
commit d384e6e764
@@ -81,13 +81,15 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.verbose = verbose

         # condition value
-        self.desired_conditioning = True
+        self.desired_conditioning = False
         self.condition_pos = None
         self.condition_vel = None

         self.max_planning_times = max_planning_times
         self.plan_counts = 0

+        self.tau_first_prediction = None
+
     def observation(self, observation):
         # return context space if we are
         if self.return_context_observation:
@@ -96,7 +98,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         return observation.astype(self.observation_space.dtype)

     def get_trajectory(self, action: np.ndarray) -> Tuple:
-        duration = self.duration
+        # duration = self.duration
+        # duration = self.duration - self.current_traj_steps * self.dt
+        duration = 2.
         if self.learn_sub_trajectories:
             duration = None
             # reset with every new call as we need to set all arguments, such as tau, delay, again.
@@ -117,7 +121,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.condition_pos = torch.as_tensor(self.condition_pos, dtype=torch.float32)
         self.condition_vel = torch.as_tensor(self.condition_vel, dtype=torch.float32)
         self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
-        self.traj_gen.set_duration(duration, self.dt)
+        # self.traj_gen.set_duration(duration, self.dt)
+        self.traj_gen.set_duration(self.tau_first_prediction, self.dt)
         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
         position = get_numpy(self.traj_gen.get_traj_pos())
         velocity = get_numpy(self.traj_gen.get_traj_vel())
@@ -160,6 +165,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def step(self, action: np.ndarray):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""

+        if self.plan_counts == 0:
+            self.tau_first_prediction = action[0]
+
         ## tricky part, only use weights basis
         # basis_weights = action.reshape(7, -1)
         # goal_weights = np.zeros((7, 1))
@@ -253,5 +261,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
         self.current_traj_steps = 0
         self.plan_counts = 0
+        self.tau_first_prediction = None
         self.traj_gen.reset()
         return super(BlackBoxWrapper, self).reset()
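The wrapper changes above cache the tau predicted by the very first plan of an episode and reuse it as the trajectory duration for every later replanning call, clearing the cache on reset(). A minimal sketch of that pattern follows; the class name is hypothetical and it only assumes an object exposing set_duration(duration, dt) like the traj_gen used above, so it is an illustration rather than the full BlackBoxWrapper.

import numpy as np

class TauCachingPlannerSketch:
    """Illustrative sketch of the tau-caching behaviour introduced above."""

    def __init__(self, traj_gen, dt):
        self.traj_gen = traj_gen              # anything with set_duration(duration, dt)
        self.dt = dt
        self.plan_counts = 0
        self.tau_first_prediction = None      # duration predicted by the first plan

    def plan(self, action: np.ndarray):
        # Only the very first plan of an episode is allowed to set tau.
        if self.plan_counts == 0:
            self.tau_first_prediction = action[0]
        # Every subsequent replanning call reuses the cached duration.
        self.traj_gen.set_duration(self.tau_first_prediction, self.dt)
        self.plan_counts += 1

    def reset(self):
        self.plan_counts = 0
        self.tau_first_prediction = None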
@@ -254,7 +254,7 @@ for ctxt_dim in [2, 4]:
     register(
         id='TableTennis{}D-v0'.format(ctxt_dim),
         entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
-        max_episode_steps=350,
+        max_episode_steps=500,
         kwargs={
             "ctxt_dim": ctxt_dim,
             'frame_skip': 4
@@ -523,7 +523,7 @@ for _v in _versions:
     kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
     kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
     kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
-    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2
+    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4
     kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
     register(
         id=_env_id,
@@ -555,6 +555,35 @@ for _v in _versions:
         kwargs=kwargs_dict_tt_promp
     )
     ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)

+for _v in _versions:
+    _name = _v.split("-")
+    _env_id = f'{_name[0]}ProDMP-{_name[1]}'
+    kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
+    kwargs_dict_tt_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
+    kwargs_dict_tt_prodmp['name'] = _v
+    kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
+    kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
+    kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['weights_scale'] = 1.0
+    kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0
+    kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = True
+    kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_tau'] = True
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True
+    kwargs_dict_tt_prodmp['basis_generator_kwargs']['num_basis'] = 2
+    kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25.
+    kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
+    kwargs_dict_tt_prodmp['basis_generator_kwargs']['pre_compute_length_factor'] = 5
+    kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
+    # kwargs_dict_tt_prodmp['black_box_kwargs']['duration'] = 4.
+    kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 4
+    kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0
+    register(
+        id=_env_id,
+        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_tt_prodmp
+    )
+    ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
 #
 # ## Walker2DJump
 # _versions = ['Walker2DJump-v0']
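The new registration loop above produces IDs such as 'TableTennis4DProDMP-v0', the variant exercised in the example script further below. A minimal roll-out sketch, assuming the fancy_gym.make helper and the same step interface already used in those examples (old gym-style 4-tuple returns):

import fancy_gym

# Quick check of the newly registered ProDMP table tennis variant (illustrative only).
env = fancy_gym.make("TableTennis4DProDMP-v0", seed=10)
obs = env.reset()
ac = env.action_space.sample()          # MP parameters, incl. learned tau/delay
obs, reward, done, info = env.step(ac)  # executes one planned (sub-)trajectory
print(reward, done)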
@@ -127,8 +127,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):

     def reset_model(self):
         self._steps = 0
-        self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
-        self._goal_pos = self._generate_goal_pos(random=True)
+        self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
+        self._goal_pos = self._generate_goal_pos(random=False)
         self.data.joint("tar_x").qpos = self._init_ball_state[0]
         self.data.joint("tar_y").qpos = self._init_ball_state[1]
         self.data.joint("tar_z").qpos = self._init_ball_state[2]
@@ -214,6 +214,6 @@ if __name__ == "__main__":
     for _ in range(2000):
         env.render("human")
         obs, reward, done, info = env.step(np.zeros(7))
-        print(reward)
+        # print(reward)
         if done:
             break
@@ -40,6 +40,8 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
         # Now the action space is not the raw action but the parametrization of the trajectory generator,
         # such as a ProMP
         ac = env.action_space.sample()
+        # ac[0] = 0.6866657733917236
+        # ac[1] = 0.08587364107370377
         # This executes a full trajectory and gives back the context (obs) of the last step in the trajectory, or the
         # full observation space of the last step, if replanning/sub-trajectory learning is used. The 'reward' is equal
         # to the return of a trajectory. Default is the sum over the step-wise rewards.
@@ -50,7 +52,8 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
         if done:
             # print(reward)
             obs = env.reset()
-            print("steps: {}".format(info["num_steps"][-1]))
+            print("=================New Episode======================")
+            # print("steps: {}".format(info["num_steps"][-1]))


 def example_custom_mp(env_name="Reacher5dProMP-v0", seed=1, iterations=1, render=True):
@@ -158,17 +161,18 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):


 if __name__ == '__main__':
-    render = True
+    render = False
     # DMP
     # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)

     # ProMP
     # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
     # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
-    example_mp("TableTennis4DProMP-v0", seed=10, iterations=10, render=True)
+    # example_mp("TableTennis4DProMP-v0", seed=10, iterations=10, render=True)

     # ProDMP
     # example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
+    example_mp("TableTennis4DProDMP-v0", seed=10, iterations=5000, render=render)

     # Altered basis functions
     # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)