add weights scaling for box pushing replanning
This commit is contained in:
parent
c457fbbfeb
commit
556bfd0b35
@ -46,7 +46,7 @@ pip install -e .
|
||||
In case you want to use dm_control oder metaworld, you can install them by specifying extras
|
||||
|
||||
```bash
|
||||
pip install -e .[dmc, metaworld]
|
||||
pip install -e .[dmc,metaworld]
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
|
@ -56,6 +56,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
# reward computation
|
||||
self.reward_aggregation = reward_aggregation
|
||||
|
||||
# self.traj_gen.basis_gn.show_basis(plot=True)
|
||||
# spaces
|
||||
self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
|
||||
self.traj_gen_action_space = self._get_traj_gen_action_space()
|
||||
|
@ -68,8 +68,7 @@ DEFAULT_BB_DICT_ProDMP = {
|
||||
"wrappers": [],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'prodmp',
|
||||
'duration': 2.0,
|
||||
'weight_scale': 1.0,
|
||||
'weights_scale': 1.0,
|
||||
},
|
||||
"phase_generator_kwargs": {
|
||||
'phase_generator_type': 'exp',
|
||||
@ -497,6 +496,8 @@ for _v in _versions:
|
||||
kwargs_dict_box_pushing_prodmp['name'] = _v
|
||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
|
||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
|
||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0
|
||||
register(
|
||||
id=_env_id,
|
||||
|
@ -16,8 +16,8 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
[False] * 7, # joints gravity compensation
|
||||
[False] * 3, # position of rod tip
|
||||
[False] * 4, # orientation of rod
|
||||
[False] * 3, # position of box
|
||||
[False] * 4, # orientation of box
|
||||
[True] * 3, # position of box
|
||||
[True] * 4, # orientation of box
|
||||
[True] * 3, # position of target
|
||||
[True] * 4, # orientation of target
|
||||
# [True] * 1, # time
|
||||
|
@ -157,17 +157,17 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
if __name__ == '__main__':
|
||||
render = True
|
||||
# DMP
|
||||
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||
|
||||
# ProMP
|
||||
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
||||
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
||||
|
||||
# ProDMP
|
||||
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=1, render=render)
|
||||
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
|
||||
|
||||
# Altered basis functions
|
||||
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||
|
||||
# Custom MP
|
||||
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||
# example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||
|
15
test/test_replanning_envs.py
Normal file
15
test/test_replanning_envs.py
Normal file
@ -0,0 +1,15 @@
|
||||
from itertools import chain
|
||||
|
||||
import pytest
|
||||
|
||||
import fancy_gym
|
||||
from test.utils import run_env, run_env_determinism
|
||||
|
||||
Fancy_ProDMP_IDS = fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
|
||||
|
||||
All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
|
||||
|
||||
@pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS)
|
||||
def test_prodmp_envs(env_id: str):
|
||||
"""Tests that ProDMP environments run without errors using random actions."""
|
||||
run_env(env_id)
|
Loading…
Reference in New Issue
Block a user