New reacher MP wrapper for Philipp

This commit is contained in:
Onur 2022-07-07 09:39:20 +02:00
parent 7bd9848c31
commit 819fca1b2e
4 changed files with 31 additions and 18 deletions

View File

@ -668,21 +668,33 @@ for _v in _versions:
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_mp_env_helper',
kwargs={ kwargs={
"name": f"alr_envs:{_v}", "name": f"alr_envs:{_v}",
"wrappers": [mujoco.reacher.MPWrapper], "wrappers": [mujoco.reacher.NewMPWrapper],
"mp_kwargs": { "ep_wrapper_kwargs": {
"num_dof": 5 if "long" not in _v.lower() else 7, "weight_scale": 1
"num_basis": 2, },
"duration": 4, "movement_primitives_kwargs": {
"policy_type": "motor", 'movement_primitives_type': 'promp',
"weights_scale": 5, 'action_dim': 5 if "long" not in _v.lower() else 7
"zero_start": True, },
"policy_kwargs": { "phase_generator_kwargs": {
"p_gains": 1, 'phase_generator_type': 'linear',
"d_gains": 0.1 'delay': 0,
} 'tau': 4, # initial value
'learn_tau': False,
'learn_delay': False
},
"controller_kwargs": {
'controller_type': 'motor',
"p_gains": 1,
"d_gains": 0.1
},
"basis_generator_kwargs": {
'basis_generator_type': 'zero_rbf',
'num_basis': 2,
'num_basis_zero_start': 1
} }
} }
) )
@ -1275,4 +1287,4 @@ for _v in _versions:
} }
} }
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)

View File

@ -1 +1,2 @@
from .mp_wrapper import MPWrapper from .mp_wrapper import MPWrapper
from .new_mp_wrapper import NewMPWrapper

View File

@ -149,4 +149,4 @@ if __name__ == '__main__':
if d: if d:
env.reset() env.reset()
env.close() env.close()

View File

@ -3,7 +3,7 @@ from typing import Union, Tuple
import numpy as np import numpy as np
class MPWrapper(EpisodicWrapper): class NewMPWrapper(EpisodicWrapper):
@property @property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
@ -21,4 +21,4 @@ class MPWrapper(EpisodicWrapper):
[False] * 3, # goal distance [False] * 3, # goal distance
# self.get_body_com("target"), # only return target to make problem harder # self.get_body_com("target"), # only return target to make problem harder
[False], # step [False], # step
]) ])