renameing alr module and updating tests
This commit is contained in:
parent
79c26681c9
commit
0339361656
@ -107,7 +107,7 @@ keys `DMP` and `ProMP` that store a list of available environment names.
|
|||||||
import alr_envs
|
import alr_envs
|
||||||
|
|
||||||
print("Custom MP tasks:")
|
print("Custom MP tasks:")
|
||||||
print(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS)
|
print(alr_envs.ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
|
||||||
|
|
||||||
print("OpenAI Gym MP tasks:")
|
print("OpenAI Gym MP tasks:")
|
||||||
print(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS)
|
print(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS)
|
||||||
@ -116,7 +116,7 @@ print("Deepmind Control MP tasks:")
|
|||||||
print(alr_envs.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
|
print(alr_envs.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
|
||||||
|
|
||||||
print("MetaWorld MP tasks:")
|
print("MetaWorld MP tasks:")
|
||||||
print(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS)
|
print(alr_envs.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
|
||||||
```
|
```
|
||||||
|
|
||||||
### How to create a new MP task
|
### How to create a new MP task
|
||||||
|
@ -2,13 +2,13 @@ from alr_envs import dmc, meta, open_ai
|
|||||||
from alr_envs.utils.make_env_helpers import make, make_bb, make_rank
|
from alr_envs.utils.make_env_helpers import make, make_bb, make_rank
|
||||||
|
|
||||||
# Convenience function for all MP environments
|
# Convenience function for all MP environments
|
||||||
from .alr import ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS
|
from .envs import ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS
|
||||||
from .dmc import ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS
|
from .dmc import ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS
|
||||||
from .meta import ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS
|
from .meta import ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS
|
||||||
from .open_ai import ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS
|
from .open_ai import ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS
|
||||||
|
|
||||||
ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {
|
ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {
|
||||||
key: value + ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key] +
|
key: value + ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key] +
|
||||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS[key] +
|
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS[key] +
|
||||||
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS[key]
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key]
|
||||||
for key, value in ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS.items()}
|
for key, value in ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS.items()}
|
||||||
|
@ -16,7 +16,7 @@ from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPER
|
|||||||
from .mujoco.reacher.reacher import ReacherEnv
|
from .mujoco.reacher.reacher import ReacherEnv
|
||||||
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
||||||
|
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||||
|
|
||||||
DEFAULT_BB_DICT_ProMP = {
|
DEFAULT_BB_DICT_ProMP = {
|
||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
@ -63,7 +63,7 @@ DEFAULT_BB_DICT_DMP = {
|
|||||||
## Simple Reacher
|
## Simple Reacher
|
||||||
register(
|
register(
|
||||||
id='SimpleReacher-v0',
|
id='SimpleReacher-v0',
|
||||||
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv',
|
entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 2,
|
"n_links": 2,
|
||||||
@ -72,7 +72,7 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='LongSimpleReacher-v0',
|
id='LongSimpleReacher-v0',
|
||||||
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv',
|
entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -83,7 +83,7 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='ViaPointReacher-v0',
|
id='ViaPointReacher-v0',
|
||||||
entry_point='alr_envs.alr.classic_control:ViaPointReacherEnv',
|
entry_point='alr_envs.envs.classic_control:ViaPointReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -95,7 +95,7 @@ register(
|
|||||||
## Hole Reacher
|
## Hole Reacher
|
||||||
register(
|
register(
|
||||||
id='HoleReacher-v0',
|
id='HoleReacher-v0',
|
||||||
entry_point='alr_envs.alr.classic_control:HoleReacherEnv',
|
entry_point='alr_envs.envs.classic_control:HoleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -115,7 +115,7 @@ register(
|
|||||||
for _dims in [5, 7]:
|
for _dims in [5, 7]:
|
||||||
register(
|
register(
|
||||||
id=f'Reacher{_dims}d-v0',
|
id=f'Reacher{_dims}d-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ReacherEnv',
|
entry_point='alr_envs.envs.mujoco:ReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": _dims,
|
"n_links": _dims,
|
||||||
@ -124,7 +124,7 @@ for _dims in [5, 7]:
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id=f'Reacher{_dims}dSparse-v0',
|
id=f'Reacher{_dims}dSparse-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ReacherEnv',
|
entry_point='alr_envs.envs.mujoco:ReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"sparse": True,
|
"sparse": True,
|
||||||
@ -134,7 +134,7 @@ for _dims in [5, 7]:
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='HopperJumpSparse-v0',
|
id='HopperJumpSparse-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:HopperJumpEnv',
|
entry_point='alr_envs.envs.mujoco:HopperJumpEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||||
kwargs={
|
kwargs={
|
||||||
"sparse": True,
|
"sparse": True,
|
||||||
@ -143,7 +143,7 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='HopperJump-v0',
|
id='HopperJump-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:HopperJumpEnv',
|
entry_point='alr_envs.envs.mujoco:HopperJumpEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||||
kwargs={
|
kwargs={
|
||||||
"sparse": False,
|
"sparse": False,
|
||||||
@ -155,43 +155,43 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRAntJump-v0',
|
id='ALRAntJump-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:AntJumpEnv',
|
entry_point='alr_envs.envs.mujoco:AntJumpEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRHalfCheetahJump-v0',
|
id='ALRHalfCheetahJump-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ALRHalfCheetahJumpEnv',
|
entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='HopperJumpOnBox-v0',
|
id='HopperJumpOnBox-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv',
|
entry_point='alr_envs.envs.mujoco:HopperJumpOnBoxEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRHopperThrow-v0',
|
id='ALRHopperThrow-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ALRHopperThrowEnv',
|
entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRHopperThrowInBasket-v0',
|
id='ALRHopperThrowInBasket-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ALRHopperThrowInBasketEnv',
|
entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRWalker2DJump-v0',
|
id='ALRWalker2DJump-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ALRWalker2dJumpEnv',
|
entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='BeerPong-v0',
|
id='BeerPong-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:BeerPongEnv',
|
entry_point='alr_envs.envs.mujoco:BeerPongEnv',
|
||||||
max_episode_steps=300,
|
max_episode_steps=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -199,14 +199,14 @@ register(
|
|||||||
# only one time step, i.e. we simulate until the end of th episode
|
# only one time step, i.e. we simulate until the end of th episode
|
||||||
register(
|
register(
|
||||||
id='BeerPongStepBased-v0',
|
id='BeerPongStepBased-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:BeerPongEnvStepBasedEpisodicReward',
|
entry_point='alr_envs.envs.mujoco:BeerPongEnvStepBasedEpisodicReward',
|
||||||
max_episode_steps=300,
|
max_episode_steps=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Beerpong with episodic reward, but fixed release time step
|
# Beerpong with episodic reward, but fixed release time step
|
||||||
register(
|
register(
|
||||||
id='BeerPongFixedRelease-v0',
|
id='BeerPongFixedRelease-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:BeerPongEnvFixedReleaseStep',
|
entry_point='alr_envs.envs.mujoco:BeerPongEnvFixedReleaseStep',
|
||||||
max_episode_steps=300,
|
max_episode_steps=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -229,7 +229,7 @@ for _v in _versions:
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs=kwargs_dict_simple_reacher_dmp
|
kwargs=kwargs_dict_simple_reacher_dmp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||||
|
|
||||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
||||||
kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
@ -242,7 +242,7 @@ for _v in _versions:
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs=kwargs_dict_simple_reacher_promp
|
kwargs=kwargs_dict_simple_reacher_promp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
# Viapoint reacher
|
# Viapoint reacher
|
||||||
kwargs_dict_via_point_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_via_point_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
@ -257,7 +257,7 @@ register(
|
|||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs=kwargs_dict_via_point_reacher_dmp
|
kwargs=kwargs_dict_via_point_reacher_dmp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0")
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0")
|
||||||
|
|
||||||
kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
|
kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
|
||||||
@ -268,7 +268,7 @@ register(
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs=kwargs_dict_via_point_reacher_promp
|
kwargs=kwargs_dict_via_point_reacher_promp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
|
||||||
|
|
||||||
## Hole Reacher
|
## Hole Reacher
|
||||||
_versions = ["HoleReacher-v0"]
|
_versions = ["HoleReacher-v0"]
|
||||||
@ -288,7 +288,7 @@ for _v in _versions:
|
|||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs=kwargs_dict_hole_reacher_dmp
|
kwargs=kwargs_dict_hole_reacher_dmp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||||
|
|
||||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
||||||
kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
@ -301,7 +301,7 @@ for _v in _versions:
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs=kwargs_dict_hole_reacher_promp
|
kwargs=kwargs_dict_hole_reacher_promp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
## ReacherNd
|
## ReacherNd
|
||||||
_versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"]
|
_versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"]
|
||||||
@ -320,7 +320,7 @@ for _v in _versions:
|
|||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs=kwargs_dict_reacherNd_dmp
|
kwargs=kwargs_dict_reacherNd_dmp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||||
|
|
||||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
||||||
kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
@ -333,7 +333,7 @@ for _v in _versions:
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs=kwargs_dict_alr_reacher_promp
|
kwargs=kwargs_dict_alr_reacher_promp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
########################################################################################################################
|
########################################################################################################################
|
||||||
## Beerpong ProMP
|
## Beerpong ProMP
|
||||||
@ -354,7 +354,7 @@ for _v in _versions:
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs=kwargs_dict_bp_promp
|
kwargs=kwargs_dict_bp_promp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
### BP with Fixed release
|
### BP with Fixed release
|
||||||
_versions = ["BeerPongStepBased-v0", "BeerPongFixedRelease-v0"]
|
_versions = ["BeerPongStepBased-v0", "BeerPongFixedRelease-v0"]
|
||||||
@ -374,7 +374,7 @@ for _v in _versions:
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs=kwargs_dict_bp_promp
|
kwargs=kwargs_dict_bp_promp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
########################################################################################################################
|
########################################################################################################################
|
||||||
|
|
||||||
## Table Tennis needs to be fixed according to Zhou's implementation
|
## Table Tennis needs to be fixed according to Zhou's implementation
|
||||||
@ -395,7 +395,7 @@ for _v in _versions:
|
|||||||
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
# kwargs=kwargs_dict_ant_jump_promp
|
# kwargs=kwargs_dict_ant_jump_promp
|
||||||
# )
|
# )
|
||||||
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
# ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
#
|
#
|
||||||
# ########################################################################################################################
|
# ########################################################################################################################
|
||||||
#
|
#
|
||||||
@ -412,7 +412,7 @@ for _v in _versions:
|
|||||||
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
# kwargs=kwargs_dict_halfcheetah_jump_promp
|
# kwargs=kwargs_dict_halfcheetah_jump_promp
|
||||||
# )
|
# )
|
||||||
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
# ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
#
|
#
|
||||||
# ########################################################################################################################
|
# ########################################################################################################################
|
||||||
|
|
||||||
@ -433,7 +433,7 @@ for _v in _versions:
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs=kwargs_dict_hopper_jump_promp
|
kwargs=kwargs_dict_hopper_jump_promp
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
# ########################################################################################################################
|
# ########################################################################################################################
|
||||||
#
|
#
|
||||||
@ -451,13 +451,13 @@ for _v in _versions:
|
|||||||
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
# kwargs=kwargs_dict_walker2d_jump_promp
|
# kwargs=kwargs_dict_walker2d_jump_promp
|
||||||
# )
|
# )
|
||||||
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
# ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
### Depricated, we will not provide non random starts anymore
|
### Depricated, we will not provide non random starts anymore
|
||||||
"""
|
"""
|
||||||
register(
|
register(
|
||||||
id='SimpleReacher-v1',
|
id='SimpleReacher-v1',
|
||||||
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv',
|
entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 2,
|
"n_links": 2,
|
||||||
@ -467,7 +467,7 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='LongSimpleReacher-v1',
|
id='LongSimpleReacher-v1',
|
||||||
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv',
|
entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -476,7 +476,7 @@ register(
|
|||||||
)
|
)
|
||||||
register(
|
register(
|
||||||
id='HoleReacher-v1',
|
id='HoleReacher-v1',
|
||||||
entry_point='alr_envs.alr.classic_control:HoleReacherEnv',
|
entry_point='alr_envs.envs.classic_control:HoleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -491,7 +491,7 @@ register(
|
|||||||
)
|
)
|
||||||
register(
|
register(
|
||||||
id='HoleReacher-v2',
|
id='HoleReacher-v2',
|
||||||
entry_point='alr_envs.alr.classic_control:HoleReacherEnv',
|
entry_point='alr_envs.envs.classic_control:HoleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -508,7 +508,7 @@ register(
|
|||||||
# CtxtFree are v0, Contextual are v1
|
# CtxtFree are v0, Contextual are v1
|
||||||
register(
|
register(
|
||||||
id='ALRAntJump-v0',
|
id='ALRAntJump-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:AntJumpEnv',
|
entry_point='alr_envs.envs.mujoco:AntJumpEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
|
||||||
kwargs={
|
kwargs={
|
||||||
"max_episode_steps": MAX_EPISODE_STEPS_ANTJUMP,
|
"max_episode_steps": MAX_EPISODE_STEPS_ANTJUMP,
|
||||||
@ -518,7 +518,7 @@ register(
|
|||||||
# CtxtFree are v0, Contextual are v1
|
# CtxtFree are v0, Contextual are v1
|
||||||
register(
|
register(
|
||||||
id='ALRHalfCheetahJump-v0',
|
id='ALRHalfCheetahJump-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ALRHalfCheetahJumpEnv',
|
entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
|
||||||
kwargs={
|
kwargs={
|
||||||
"max_episode_steps": MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
|
"max_episode_steps": MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
|
||||||
@ -527,7 +527,7 @@ register(
|
|||||||
)
|
)
|
||||||
register(
|
register(
|
||||||
id='ALRHopperJump-v0',
|
id='ALRHopperJump-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:HopperJumpEnv',
|
entry_point='alr_envs.envs.mujoco:HopperJumpEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||||
kwargs={
|
kwargs={
|
||||||
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP,
|
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||||
@ -545,7 +545,7 @@ for i in _vs:
|
|||||||
_env_id = f'ALRReacher{i}-v0'
|
_env_id = f'ALRReacher{i}-v0'
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
entry_point='alr_envs.alr.mujoco:ReacherEnv',
|
entry_point='alr_envs.envs.mujoco:ReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"steps_before_reward": 0,
|
"steps_before_reward": 0,
|
||||||
@ -558,7 +558,7 @@ for i in _vs:
|
|||||||
_env_id = f'ALRReacherSparse{i}-v0'
|
_env_id = f'ALRReacherSparse{i}-v0'
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
entry_point='alr_envs.alr.mujoco:ReacherEnv',
|
entry_point='alr_envs.envs.mujoco:ReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"steps_before_reward": 200,
|
"steps_before_reward": 200,
|
||||||
@ -617,7 +617,7 @@ for i in _vs:
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRHopperJumpOnBox-v0',
|
id='ALRHopperJumpOnBox-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv',
|
entry_point='alr_envs.envs.mujoco:HopperJumpOnBoxEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
|
||||||
kwargs={
|
kwargs={
|
||||||
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
|
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
|
||||||
@ -626,7 +626,7 @@ for i in _vs:
|
|||||||
)
|
)
|
||||||
register(
|
register(
|
||||||
id='ALRHopperThrow-v0',
|
id='ALRHopperThrow-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ALRHopperThrowEnv',
|
entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
|
||||||
kwargs={
|
kwargs={
|
||||||
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROW,
|
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROW,
|
||||||
@ -635,7 +635,7 @@ for i in _vs:
|
|||||||
)
|
)
|
||||||
register(
|
register(
|
||||||
id='ALRHopperThrowInBasket-v0',
|
id='ALRHopperThrowInBasket-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ALRHopperThrowInBasketEnv',
|
entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
|
||||||
kwargs={
|
kwargs={
|
||||||
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
|
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
|
||||||
@ -644,7 +644,7 @@ for i in _vs:
|
|||||||
)
|
)
|
||||||
register(
|
register(
|
||||||
id='ALRWalker2DJump-v0',
|
id='ALRWalker2DJump-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ALRWalker2dJumpEnv',
|
entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
|
||||||
kwargs={
|
kwargs={
|
||||||
"max_episode_steps": MAX_EPISODE_STEPS_WALKERJUMP,
|
"max_episode_steps": MAX_EPISODE_STEPS_WALKERJUMP,
|
||||||
@ -652,13 +652,13 @@ for i in _vs:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
register(id='TableTennis2DCtxt-v1',
|
register(id='TableTennis2DCtxt-v1',
|
||||||
entry_point='alr_envs.alr.mujoco:TTEnvGym',
|
entry_point='alr_envs.envs.mujoco:TTEnvGym',
|
||||||
max_episode_steps=MAX_EPISODE_STEPS,
|
max_episode_steps=MAX_EPISODE_STEPS,
|
||||||
kwargs={'ctxt_dim': 2, 'fixed_goal': True})
|
kwargs={'ctxt_dim': 2, 'fixed_goal': True})
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRBeerPong-v0',
|
id='ALRBeerPong-v0',
|
||||||
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
|
entry_point='alr_envs.envs.mujoco:ALRBeerBongEnv',
|
||||||
max_episode_steps=300,
|
max_episode_steps=300,
|
||||||
kwargs={
|
kwargs={
|
||||||
"rndm_goal": False,
|
"rndm_goal": False,
|
@ -7,7 +7,7 @@ from gym import spaces
|
|||||||
from gym.core import ObsType
|
from gym.core import ObsType
|
||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
from alr_envs.alr.classic_control.utils import intersect
|
from alr_envs.envs.classic_control.utils import intersect
|
||||||
|
|
||||||
|
|
||||||
class BaseReacherEnv(gym.Env, ABC):
|
class BaseReacherEnv(gym.Env, ABC):
|
@ -2,7 +2,7 @@ from abc import ABC
|
|||||||
|
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from alr_envs.alr.classic_control.base_reacher.base_reacher import BaseReacherEnv
|
from alr_envs.envs.classic_control.base_reacher.base_reacher import BaseReacherEnv
|
||||||
|
|
||||||
|
|
||||||
class BaseReacherDirectEnv(BaseReacherEnv, ABC):
|
class BaseReacherDirectEnv(BaseReacherEnv, ABC):
|
@ -2,7 +2,7 @@ from abc import ABC
|
|||||||
|
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from alr_envs.alr.classic_control.base_reacher.base_reacher import BaseReacherEnv
|
from alr_envs.envs.classic_control.base_reacher.base_reacher import BaseReacherEnv
|
||||||
|
|
||||||
|
|
||||||
class BaseReacherTorqueEnv(BaseReacherEnv, ABC):
|
class BaseReacherTorqueEnv(BaseReacherEnv, ABC):
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from gym.core import ObsType
|
from gym.core import ObsType
|
||||||
from matplotlib import patches
|
from matplotlib import patches
|
||||||
|
|
||||||
from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
|
from alr_envs.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
|
||||||
|
|
||||||
|
|
||||||
class HoleReacherEnv(BaseReacherDirectEnv):
|
class HoleReacherEnv(BaseReacherDirectEnv):
|
||||||
@ -41,13 +41,13 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
self.observation_space = gym.spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
|
self.observation_space = gym.spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
|
||||||
|
|
||||||
if rew_fct == "simple":
|
if rew_fct == "simple":
|
||||||
from alr_envs.alr.classic_control.hole_reacher.hr_simple_reward import HolereacherReward
|
from alr_envs.envs.classic_control.hole_reacher.hr_simple_reward import HolereacherReward
|
||||||
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty)
|
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty)
|
||||||
elif rew_fct == "vel_acc":
|
elif rew_fct == "vel_acc":
|
||||||
from alr_envs.alr.classic_control.hole_reacher.hr_dist_vel_acc_reward import HolereacherReward
|
from alr_envs.envs.classic_control.hole_reacher.hr_dist_vel_acc_reward import HolereacherReward
|
||||||
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty)
|
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty)
|
||||||
elif rew_fct == "unbounded":
|
elif rew_fct == "unbounded":
|
||||||
from alr_envs.alr.classic_control.hole_reacher.hr_unbounded_reward import HolereacherReward
|
from alr_envs.envs.classic_control.hole_reacher.hr_unbounded_reward import HolereacherReward
|
||||||
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision)
|
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown reward function {}".format(rew_fct))
|
raise ValueError("Unknown reward function {}".format(rew_fct))
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from gym import spaces
|
from gym import spaces
|
||||||
from gym.core import ObsType
|
from gym.core import ObsType
|
||||||
|
|
||||||
from alr_envs.alr.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv
|
from alr_envs.envs.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv
|
||||||
|
|
||||||
|
|
||||||
class SimpleReacherEnv(BaseReacherTorqueEnv):
|
class SimpleReacherEnv(BaseReacherTorqueEnv):
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from gym.core import ObsType
|
from gym.core import ObsType
|
||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
|
from alr_envs.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
|
||||||
|
|
||||||
|
|
||||||
class ViaPointReacherEnv(BaseReacherDirectEnv):
|
class ViaPointReacherEnv(BaseReacherDirectEnv):
|
@ -55,7 +55,7 @@ class AntJumpEnv(AntEnv):
|
|||||||
|
|
||||||
costs = ctrl_cost + contact_cost
|
costs = ctrl_cost + contact_cost
|
||||||
|
|
||||||
done = height < 0.3 # fall over -> is the 0.3 value from healthy_z_range? TODO change 0.3 to the value of healthy z angle
|
done = bool(height < 0.3) # fall over -> is the 0.3 value from healthy_z_range? TODO change 0.3 to the value of healthy z angle
|
||||||
|
|
||||||
if self.current_step == MAX_EPISODE_STEPS_ANTJUMP or done:
|
if self.current_step == MAX_EPISODE_STEPS_ANTJUMP or done:
|
||||||
# -10 for scaling the value of the distance between the max_height and the goal height; only used when context is enabled
|
# -10 for scaling the value of the distance between the max_height and the goal height; only used when context is enabled
|
||||||
@ -84,8 +84,8 @@ class AntJumpEnv(AntEnv):
|
|||||||
options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
|
options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
self.max_height = 0
|
self.max_height = 0
|
||||||
self.goal = self.np_random.uniform(1.0, 2.5,
|
# goal heights from 1.0 to 2.5; can be increased, but didnt work well with CMORE
|
||||||
1) # goal heights from 1.0 to 2.5; can be increased, but didnt work well with CMORE
|
self.goal = self.np_random.uniform(1.0, 2.5, 1)
|
||||||
return super().reset()
|
return super().reset()
|
||||||
|
|
||||||
# reset_model had to be implemented in every env to make it deterministic
|
# reset_model had to be implemented in every env to make it deterministic
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from gym import utils
|
from gym import utils
|
||||||
from gym.envs.mujoco import MujocoEnv
|
from gym.envs.mujoco import MujocoEnv
|
||||||
|
|
||||||
from alr_envs.alr.mujoco.beerpong.deprecated.beerpong_reward_staged import BeerPongReward
|
from alr_envs.envs.mujoco.beerpong.deprecated.beerpong_reward_staged import BeerPongReward
|
||||||
|
|
||||||
|
|
||||||
class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user