add weights scaling for box pushing replanning

This commit is contained in:
Hongyi Zhou 2022-10-25 20:10:59 +02:00
parent c457fbbfeb
commit 556bfd0b35
6 changed files with 28 additions and 11 deletions

View File

@ -56,6 +56,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
# reward computation # reward computation
self.reward_aggregation = reward_aggregation self.reward_aggregation = reward_aggregation
# self.traj_gen.basis_gn.show_basis(plot=True)
# spaces # spaces
self.return_context_observation = not (learn_sub_trajectories or self.do_replanning) self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
self.traj_gen_action_space = self._get_traj_gen_action_space() self.traj_gen_action_space = self._get_traj_gen_action_space()

View File

@ -68,8 +68,7 @@ DEFAULT_BB_DICT_ProDMP = {
"wrappers": [], "wrappers": [],
"trajectory_generator_kwargs": { "trajectory_generator_kwargs": {
'trajectory_generator_type': 'prodmp', 'trajectory_generator_type': 'prodmp',
'duration': 2.0, 'weights_scale': 1.0,
'weight_scale': 1.0,
}, },
"phase_generator_kwargs": { "phase_generator_kwargs": {
'phase_generator_type': 'exp', 'phase_generator_type': 'exp',
@ -497,6 +496,8 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['name'] = _v kwargs_dict_box_pushing_prodmp['name'] = _v
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0
register( register(
id=_env_id, id=_env_id,

View File

@ -16,8 +16,8 @@ class MPWrapper(RawInterfaceWrapper):
[False] * 7, # joints gravity compensation [False] * 7, # joints gravity compensation
[False] * 3, # position of rod tip [False] * 3, # position of rod tip
[False] * 4, # orientation of rod [False] * 4, # orientation of rod
[False] * 3, # position of box [True] * 3, # position of box
[False] * 4, # orientation of box [True] * 4, # orientation of box
[True] * 3, # position of target [True] * 3, # position of target
[True] * 4, # orientation of target [True] * 4, # orientation of target
# [True] * 1, # time # [True] * 1, # time

View File

@ -157,17 +157,17 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__': if __name__ == '__main__':
render = True render = True
# DMP # DMP
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
# ProMP # ProMP
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
# ProDMP # ProDMP
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=1, render=render) example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
# Altered basis functions # Altered basis functions
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
# Custom MP # Custom MP
example_fully_custom_mp(seed=10, iterations=1, render=render) # example_fully_custom_mp(seed=10, iterations=1, render=render)

View File

@ -0,0 +1,15 @@
from itertools import chain
import pytest
import fancy_gym
from test.utils import run_env, run_env_determinism
Fancy_ProDMP_IDS = fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
@pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS)
def test_prodmp_envs(env_id: str):
"""Tests that ProDMP environments run without errors using random actions."""
run_env(env_id)