renameing alr module and updating tests
This commit is contained in:
parent
79c26681c9
commit
0339361656
@ -107,7 +107,7 @@ keys `DMP` and `ProMP` that store a list of available environment names.
|
||||
import alr_envs
|
||||
|
||||
print("Custom MP tasks:")
|
||||
print(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS)
|
||||
print(alr_envs.ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
|
||||
|
||||
print("OpenAI Gym MP tasks:")
|
||||
print(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS)
|
||||
@ -116,7 +116,7 @@ print("Deepmind Control MP tasks:")
|
||||
print(alr_envs.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
|
||||
|
||||
print("MetaWorld MP tasks:")
|
||||
print(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS)
|
||||
print(alr_envs.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
|
||||
```
|
||||
|
||||
### How to create a new MP task
|
||||
|
@ -2,13 +2,13 @@ from alr_envs import dmc, meta, open_ai
|
||||
from alr_envs.utils.make_env_helpers import make, make_bb, make_rank
|
||||
|
||||
# Convenience function for all MP environments
|
||||
from .alr import ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS
|
||||
from .envs import ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS
|
||||
from .dmc import ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS
|
||||
from .meta import ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS
|
||||
from .meta import ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS
|
||||
from .open_ai import ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS
|
||||
|
||||
ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {
|
||||
key: value + ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key] +
|
||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS[key] +
|
||||
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS[key]
|
||||
for key, value in ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS.items()}
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key]
|
||||
for key, value in ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS.items()}
|
||||
|
@ -16,7 +16,7 @@ from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPER
|
||||
from .mujoco.reacher.reacher import ReacherEnv
|
||||
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
||||
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
|
||||
DEFAULT_BB_DICT_ProMP = {
|
||||
"name": 'EnvName',
|
||||
@ -63,7 +63,7 @@ DEFAULT_BB_DICT_DMP = {
|
||||
## Simple Reacher
|
||||
register(
|
||||
id='SimpleReacher-v0',
|
||||
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv',
|
||||
entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 2,
|
||||
@ -72,7 +72,7 @@ register(
|
||||
|
||||
register(
|
||||
id='LongSimpleReacher-v0',
|
||||
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv',
|
||||
entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -83,7 +83,7 @@ register(
|
||||
|
||||
register(
|
||||
id='ViaPointReacher-v0',
|
||||
entry_point='alr_envs.alr.classic_control:ViaPointReacherEnv',
|
||||
entry_point='alr_envs.envs.classic_control:ViaPointReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -95,7 +95,7 @@ register(
|
||||
## Hole Reacher
|
||||
register(
|
||||
id='HoleReacher-v0',
|
||||
entry_point='alr_envs.alr.classic_control:HoleReacherEnv',
|
||||
entry_point='alr_envs.envs.classic_control:HoleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -115,7 +115,7 @@ register(
|
||||
for _dims in [5, 7]:
|
||||
register(
|
||||
id=f'Reacher{_dims}d-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ReacherEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": _dims,
|
||||
@ -124,7 +124,7 @@ for _dims in [5, 7]:
|
||||
|
||||
register(
|
||||
id=f'Reacher{_dims}dSparse-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ReacherEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"sparse": True,
|
||||
@ -134,7 +134,7 @@ for _dims in [5, 7]:
|
||||
|
||||
register(
|
||||
id='HopperJumpSparse-v0',
|
||||
entry_point='alr_envs.alr.mujoco:HopperJumpEnv',
|
||||
entry_point='alr_envs.envs.mujoco:HopperJumpEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||
kwargs={
|
||||
"sparse": True,
|
||||
@ -143,7 +143,7 @@ register(
|
||||
|
||||
register(
|
||||
id='HopperJump-v0',
|
||||
entry_point='alr_envs.alr.mujoco:HopperJumpEnv',
|
||||
entry_point='alr_envs.envs.mujoco:HopperJumpEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||
kwargs={
|
||||
"sparse": False,
|
||||
@ -155,43 +155,43 @@ register(
|
||||
|
||||
register(
|
||||
id='ALRAntJump-v0',
|
||||
entry_point='alr_envs.alr.mujoco:AntJumpEnv',
|
||||
entry_point='alr_envs.envs.mujoco:AntJumpEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRHalfCheetahJump-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRHalfCheetahJumpEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
|
||||
)
|
||||
|
||||
register(
|
||||
id='HopperJumpOnBox-v0',
|
||||
entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv',
|
||||
entry_point='alr_envs.envs.mujoco:HopperJumpOnBoxEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRHopperThrow-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRHopperThrowEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRHopperThrowInBasket-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRHopperThrowInBasketEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRWalker2DJump-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRWalker2dJumpEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
|
||||
)
|
||||
|
||||
register(
|
||||
id='BeerPong-v0',
|
||||
entry_point='alr_envs.alr.mujoco:BeerPongEnv',
|
||||
entry_point='alr_envs.envs.mujoco:BeerPongEnv',
|
||||
max_episode_steps=300,
|
||||
)
|
||||
|
||||
@ -199,14 +199,14 @@ register(
|
||||
# only one time step, i.e. we simulate until the end of th episode
|
||||
register(
|
||||
id='BeerPongStepBased-v0',
|
||||
entry_point='alr_envs.alr.mujoco:BeerPongEnvStepBasedEpisodicReward',
|
||||
entry_point='alr_envs.envs.mujoco:BeerPongEnvStepBasedEpisodicReward',
|
||||
max_episode_steps=300,
|
||||
)
|
||||
|
||||
# Beerpong with episodic reward, but fixed release time step
|
||||
register(
|
||||
id='BeerPongFixedRelease-v0',
|
||||
entry_point='alr_envs.alr.mujoco:BeerPongEnvFixedReleaseStep',
|
||||
entry_point='alr_envs.envs.mujoco:BeerPongEnvFixedReleaseStep',
|
||||
max_episode_steps=300,
|
||||
)
|
||||
|
||||
@ -229,7 +229,7 @@ for _v in _versions:
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_simple_reacher_dmp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
|
||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
||||
kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
@ -242,7 +242,7 @@ for _v in _versions:
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_simple_reacher_promp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# Viapoint reacher
|
||||
kwargs_dict_via_point_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
@ -257,7 +257,7 @@ register(
|
||||
# max_episode_steps=1,
|
||||
kwargs=kwargs_dict_via_point_reacher_dmp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0")
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0")
|
||||
|
||||
kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
|
||||
@ -268,7 +268,7 @@ register(
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_via_point_reacher_promp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
|
||||
|
||||
## Hole Reacher
|
||||
_versions = ["HoleReacher-v0"]
|
||||
@ -288,7 +288,7 @@ for _v in _versions:
|
||||
# max_episode_steps=1,
|
||||
kwargs=kwargs_dict_hole_reacher_dmp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
|
||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
||||
kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
@ -301,7 +301,7 @@ for _v in _versions:
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_hole_reacher_promp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
## ReacherNd
|
||||
_versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"]
|
||||
@ -320,7 +320,7 @@ for _v in _versions:
|
||||
# max_episode_steps=1,
|
||||
kwargs=kwargs_dict_reacherNd_dmp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
|
||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
||||
kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
@ -333,7 +333,7 @@ for _v in _versions:
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_alr_reacher_promp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
########################################################################################################################
|
||||
## Beerpong ProMP
|
||||
@ -354,7 +354,7 @@ for _v in _versions:
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_bp_promp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
### BP with Fixed release
|
||||
_versions = ["BeerPongStepBased-v0", "BeerPongFixedRelease-v0"]
|
||||
@ -374,7 +374,7 @@ for _v in _versions:
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_bp_promp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
########################################################################################################################
|
||||
|
||||
## Table Tennis needs to be fixed according to Zhou's implementation
|
||||
@ -395,7 +395,7 @@ for _v in _versions:
|
||||
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
# kwargs=kwargs_dict_ant_jump_promp
|
||||
# )
|
||||
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
# ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
#
|
||||
# ########################################################################################################################
|
||||
#
|
||||
@ -412,7 +412,7 @@ for _v in _versions:
|
||||
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
# kwargs=kwargs_dict_halfcheetah_jump_promp
|
||||
# )
|
||||
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
# ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
#
|
||||
# ########################################################################################################################
|
||||
|
||||
@ -433,7 +433,7 @@ for _v in _versions:
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_hopper_jump_promp
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# ########################################################################################################################
|
||||
#
|
||||
@ -451,13 +451,13 @@ for _v in _versions:
|
||||
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||
# kwargs=kwargs_dict_walker2d_jump_promp
|
||||
# )
|
||||
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
# ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
### Depricated, we will not provide non random starts anymore
|
||||
"""
|
||||
register(
|
||||
id='SimpleReacher-v1',
|
||||
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv',
|
||||
entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 2,
|
||||
@ -467,7 +467,7 @@ register(
|
||||
|
||||
register(
|
||||
id='LongSimpleReacher-v1',
|
||||
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv',
|
||||
entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -476,7 +476,7 @@ register(
|
||||
)
|
||||
register(
|
||||
id='HoleReacher-v1',
|
||||
entry_point='alr_envs.alr.classic_control:HoleReacherEnv',
|
||||
entry_point='alr_envs.envs.classic_control:HoleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -491,7 +491,7 @@ register(
|
||||
)
|
||||
register(
|
||||
id='HoleReacher-v2',
|
||||
entry_point='alr_envs.alr.classic_control:HoleReacherEnv',
|
||||
entry_point='alr_envs.envs.classic_control:HoleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -508,7 +508,7 @@ register(
|
||||
# CtxtFree are v0, Contextual are v1
|
||||
register(
|
||||
id='ALRAntJump-v0',
|
||||
entry_point='alr_envs.alr.mujoco:AntJumpEnv',
|
||||
entry_point='alr_envs.envs.mujoco:AntJumpEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
|
||||
kwargs={
|
||||
"max_episode_steps": MAX_EPISODE_STEPS_ANTJUMP,
|
||||
@ -518,7 +518,7 @@ register(
|
||||
# CtxtFree are v0, Contextual are v1
|
||||
register(
|
||||
id='ALRHalfCheetahJump-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRHalfCheetahJumpEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
|
||||
kwargs={
|
||||
"max_episode_steps": MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
|
||||
@ -527,7 +527,7 @@ register(
|
||||
)
|
||||
register(
|
||||
id='ALRHopperJump-v0',
|
||||
entry_point='alr_envs.alr.mujoco:HopperJumpEnv',
|
||||
entry_point='alr_envs.envs.mujoco:HopperJumpEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||
kwargs={
|
||||
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||
@ -545,7 +545,7 @@ for i in _vs:
|
||||
_env_id = f'ALRReacher{i}-v0'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='alr_envs.alr.mujoco:ReacherEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"steps_before_reward": 0,
|
||||
@ -558,7 +558,7 @@ for i in _vs:
|
||||
_env_id = f'ALRReacherSparse{i}-v0'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='alr_envs.alr.mujoco:ReacherEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"steps_before_reward": 200,
|
||||
@ -617,7 +617,7 @@ for i in _vs:
|
||||
|
||||
register(
|
||||
id='ALRHopperJumpOnBox-v0',
|
||||
entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv',
|
||||
entry_point='alr_envs.envs.mujoco:HopperJumpOnBoxEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
|
||||
kwargs={
|
||||
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
|
||||
@ -626,7 +626,7 @@ for i in _vs:
|
||||
)
|
||||
register(
|
||||
id='ALRHopperThrow-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRHopperThrowEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
|
||||
kwargs={
|
||||
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROW,
|
||||
@ -635,7 +635,7 @@ for i in _vs:
|
||||
)
|
||||
register(
|
||||
id='ALRHopperThrowInBasket-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRHopperThrowInBasketEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
|
||||
kwargs={
|
||||
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
|
||||
@ -644,7 +644,7 @@ for i in _vs:
|
||||
)
|
||||
register(
|
||||
id='ALRWalker2DJump-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRWalker2dJumpEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
|
||||
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
|
||||
kwargs={
|
||||
"max_episode_steps": MAX_EPISODE_STEPS_WALKERJUMP,
|
||||
@ -652,13 +652,13 @@ for i in _vs:
|
||||
}
|
||||
)
|
||||
register(id='TableTennis2DCtxt-v1',
|
||||
entry_point='alr_envs.alr.mujoco:TTEnvGym',
|
||||
entry_point='alr_envs.envs.mujoco:TTEnvGym',
|
||||
max_episode_steps=MAX_EPISODE_STEPS,
|
||||
kwargs={'ctxt_dim': 2, 'fixed_goal': True})
|
||||
|
||||
register(
|
||||
id='ALRBeerPong-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
|
||||
entry_point='alr_envs.envs.mujoco:ALRBeerBongEnv',
|
||||
max_episode_steps=300,
|
||||
kwargs={
|
||||
"rndm_goal": False,
|
@ -7,7 +7,7 @@ from gym import spaces
|
||||
from gym.core import ObsType
|
||||
from gym.utils import seeding
|
||||
|
||||
from alr_envs.alr.classic_control.utils import intersect
|
||||
from alr_envs.envs.classic_control.utils import intersect
|
||||
|
||||
|
||||
class BaseReacherEnv(gym.Env, ABC):
|
@ -2,7 +2,7 @@ from abc import ABC
|
||||
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
from alr_envs.alr.classic_control.base_reacher.base_reacher import BaseReacherEnv
|
||||
from alr_envs.envs.classic_control.base_reacher.base_reacher import BaseReacherEnv
|
||||
|
||||
|
||||
class BaseReacherDirectEnv(BaseReacherEnv, ABC):
|
@ -2,7 +2,7 @@ from abc import ABC
|
||||
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
from alr_envs.alr.classic_control.base_reacher.base_reacher import BaseReacherEnv
|
||||
from alr_envs.envs.classic_control.base_reacher.base_reacher import BaseReacherEnv
|
||||
|
||||
|
||||
class BaseReacherTorqueEnv(BaseReacherEnv, ABC):
|
@ -6,7 +6,7 @@ import numpy as np
|
||||
from gym.core import ObsType
|
||||
from matplotlib import patches
|
||||
|
||||
from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
|
||||
from alr_envs.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
|
||||
|
||||
|
||||
class HoleReacherEnv(BaseReacherDirectEnv):
|
||||
@ -41,13 +41,13 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
||||
self.observation_space = gym.spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
|
||||
|
||||
if rew_fct == "simple":
|
||||
from alr_envs.alr.classic_control.hole_reacher.hr_simple_reward import HolereacherReward
|
||||
from alr_envs.envs.classic_control.hole_reacher.hr_simple_reward import HolereacherReward
|
||||
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty)
|
||||
elif rew_fct == "vel_acc":
|
||||
from alr_envs.alr.classic_control.hole_reacher.hr_dist_vel_acc_reward import HolereacherReward
|
||||
from alr_envs.envs.classic_control.hole_reacher.hr_dist_vel_acc_reward import HolereacherReward
|
||||
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty)
|
||||
elif rew_fct == "unbounded":
|
||||
from alr_envs.alr.classic_control.hole_reacher.hr_unbounded_reward import HolereacherReward
|
||||
from alr_envs.envs.classic_control.hole_reacher.hr_unbounded_reward import HolereacherReward
|
||||
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision)
|
||||
else:
|
||||
raise ValueError("Unknown reward function {}".format(rew_fct))
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
from gym import spaces
|
||||
from gym.core import ObsType
|
||||
|
||||
from alr_envs.alr.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv
|
||||
from alr_envs.envs.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv
|
||||
|
||||
|
||||
class SimpleReacherEnv(BaseReacherTorqueEnv):
|
@ -6,7 +6,7 @@ import numpy as np
|
||||
from gym.core import ObsType
|
||||
from gym.utils import seeding
|
||||
|
||||
from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
|
||||
from alr_envs.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
|
||||
|
||||
|
||||
class ViaPointReacherEnv(BaseReacherDirectEnv):
|
@ -55,7 +55,7 @@ class AntJumpEnv(AntEnv):
|
||||
|
||||
costs = ctrl_cost + contact_cost
|
||||
|
||||
done = height < 0.3 # fall over -> is the 0.3 value from healthy_z_range? TODO change 0.3 to the value of healthy z angle
|
||||
done = bool(height < 0.3) # fall over -> is the 0.3 value from healthy_z_range? TODO change 0.3 to the value of healthy z angle
|
||||
|
||||
if self.current_step == MAX_EPISODE_STEPS_ANTJUMP or done:
|
||||
# -10 for scaling the value of the distance between the max_height and the goal height; only used when context is enabled
|
||||
@ -84,8 +84,8 @@ class AntJumpEnv(AntEnv):
|
||||
options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||
self.current_step = 0
|
||||
self.max_height = 0
|
||||
self.goal = self.np_random.uniform(1.0, 2.5,
|
||||
1) # goal heights from 1.0 to 2.5; can be increased, but didnt work well with CMORE
|
||||
# goal heights from 1.0 to 2.5; can be increased, but didnt work well with CMORE
|
||||
self.goal = self.np_random.uniform(1.0, 2.5, 1)
|
||||
return super().reset()
|
||||
|
||||
# reset_model had to be implemented in every env to make it deterministic
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
from gym import utils
|
||||
from gym.envs.mujoco import MujocoEnv
|
||||
|
||||
from alr_envs.alr.mujoco.beerpong.deprecated.beerpong_reward_staged import BeerPongReward
|
||||
from alr_envs.envs.mujoco.beerpong.deprecated.beerpong_reward_staged import BeerPongReward
|
||||
|
||||
|
||||
class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user