new reacher mp wrapper for Philipp
This commit is contained in:
parent
7bd9848c31
commit
819fca1b2e
@ -668,21 +668,33 @@ for _v in _versions:
|
|||||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_mp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": f"alr_envs:{_v}",
|
"name": f"alr_envs:{_v}",
|
||||||
"wrappers": [mujoco.reacher.MPWrapper],
|
"wrappers": [mujoco.reacher.NewMPWrapper],
|
||||||
"mp_kwargs": {
|
"ep_wrapper_kwargs": {
|
||||||
"num_dof": 5 if "long" not in _v.lower() else 7,
|
"weight_scale": 1
|
||||||
"num_basis": 2,
|
},
|
||||||
"duration": 4,
|
"movement_primitives_kwargs": {
|
||||||
"policy_type": "motor",
|
'movement_primitives_type': 'promp',
|
||||||
"weights_scale": 5,
|
'action_dim': 5 if "long" not in _v.lower() else 7
|
||||||
"zero_start": True,
|
},
|
||||||
"policy_kwargs": {
|
"phase_generator_kwargs": {
|
||||||
"p_gains": 1,
|
'phase_generator_type': 'linear',
|
||||||
"d_gains": 0.1
|
'delay': 0,
|
||||||
}
|
'tau': 4, # initial value
|
||||||
|
'learn_tau': False,
|
||||||
|
'learn_delay': False
|
||||||
|
},
|
||||||
|
"controller_kwargs": {
|
||||||
|
'controller_type': 'motor',
|
||||||
|
"p_gains": 1,
|
||||||
|
"d_gains": 0.1
|
||||||
|
},
|
||||||
|
"basis_generator_kwargs": {
|
||||||
|
'basis_generator_type': 'zero_rbf',
|
||||||
|
'num_basis': 2,
|
||||||
|
'num_basis_zero_start': 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -1275,4 +1287,4 @@ for _v in _versions:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
@ -1 +1,2 @@
|
|||||||
from .mp_wrapper import MPWrapper
|
from .mp_wrapper import MPWrapper
|
||||||
|
from .new_mp_wrapper import NewMPWrapper
|
||||||
|
@ -149,4 +149,4 @@ if __name__ == '__main__':
|
|||||||
if d:
|
if d:
|
||||||
env.reset()
|
env.reset()
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@ -3,7 +3,7 @@ from typing import Union, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(EpisodicWrapper):
|
class NewMPWrapper(EpisodicWrapper):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
@ -21,4 +21,4 @@ class MPWrapper(EpisodicWrapper):
|
|||||||
[False] * 3, # goal distance
|
[False] * 3, # goal distance
|
||||||
# self.get_body_com("target"), # only return target to make problem harder
|
# self.get_body_com("target"), # only return target to make problem harder
|
||||||
[False], # step
|
[False], # step
|
||||||
])
|
])
|
||||||
|
Loading…
Reference in New Issue
Block a user