add contextual obs option to invalid trajectory callback
This commit is contained in:
parent
c242c32a41
commit
2735e0bf24
@ -169,8 +169,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
infos = dict()
|
infos = dict()
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if traj_is_valid:
|
if traj_is_valid:
|
||||||
self.plan_steps += 1
|
self.plan_steps += 1
|
||||||
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
||||||
@ -207,18 +205,19 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
infos.update({k: v[:t+1] for k, v in infos.items()})
|
infos.update({k: v[:t+1] for k, v in infos.items()})
|
||||||
self.current_traj_steps += t + 1
|
self.current_traj_steps += t + 1
|
||||||
|
|
||||||
if self.verbose >= 2:
|
if self.verbose >= 2:
|
||||||
infos['positions'] = position
|
infos['positions'] = position
|
||||||
infos['velocities'] = velocity
|
infos['velocities'] = velocity
|
||||||
infos['step_actions'] = actions[:t + 1]
|
infos['step_actions'] = actions[:t + 1]
|
||||||
infos['step_observations'] = observations[:t + 1]
|
infos['step_observations'] = observations[:t + 1]
|
||||||
infos['step_rewards'] = rewards[:t + 1]
|
infos['step_rewards'] = rewards[:t + 1]
|
||||||
|
|
||||||
infos['trajectory_length'] = t + 1
|
infos['trajectory_length'] = t + 1
|
||||||
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
||||||
return self.observation(obs), trajectory_return, done, infos
|
return self.observation(obs), trajectory_return, done, infos
|
||||||
else:
|
else:
|
||||||
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
|
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
|
||||||
|
self.return_context_observation)
|
||||||
return self.observation(obs), trajectory_return, done, infos
|
return self.observation(obs), trajectory_return, done, infos
|
||||||
|
|
||||||
def render(self, **kwargs):
|
def render(self, **kwargs):
|
||||||
|
@ -55,8 +55,8 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
|
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray,
|
||||||
-> Tuple[np.ndarray, float, bool, dict]:
|
return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
|
||||||
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||||
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||||
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
||||||
@ -64,6 +64,8 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
||||||
violate_high_bound_error + violate_low_bound_error
|
violate_high_bound_error + violate_low_bound_error
|
||||||
obs = np.concatenate([self.get_obs(), np.array([0])])
|
obs = np.concatenate([self.get_obs(), np.array([0])])
|
||||||
|
if return_contextual_obs:
|
||||||
|
obs = self.get_obs()
|
||||||
return obs, -invalid_penalty, True, {
|
return obs, -invalid_penalty, True, {
|
||||||
"hit_ball": [False],
|
"hit_ball": [False],
|
||||||
"ball_returned_success": [False],
|
"ball_returned_success": [False],
|
||||||
|
@ -157,17 +157,18 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = True
|
||||||
# DMP
|
# DMP
|
||||||
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||||
|
|
||||||
# ProMP
|
# ProMP
|
||||||
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||||
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
||||||
|
example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render)
|
||||||
|
|
||||||
# ProDMP
|
# ProDMP
|
||||||
example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
|
# example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
|
||||||
|
|
||||||
# Altered basis functions
|
# Altered basis functions
|
||||||
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# Custom MP
|
# Custom MP
|
||||||
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
# example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||||
|
Loading…
Reference in New Issue
Block a user